From 07cbfdbedec2804f2014c42049f23268b0dc2ec7 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 11 May 2026 14:31:38 +0200 Subject: [PATCH 01/17] [proxy] feature: bring your own proxy (#5627) --- .../reverseproxy/domain/manager/manager.go | 22 +- .../domain/manager/manager_test.go | 110 +++++ .../modules/reverseproxy/proxy/manager.go | 8 +- .../reverseproxy/proxy/manager/manager.go | 64 ++- .../proxy/manager/manager_test.go | 337 +++++++++++++++ .../reverseproxy/proxy/manager_mock.go | 79 +++- .../modules/reverseproxy/proxy/proxy.go | 12 +- .../reverseproxy/proxytoken/handler.go | 195 +++++++++ .../reverseproxy/proxytoken/handler_test.go | 275 ++++++++++++ .../modules/reverseproxy/service/interface.go | 2 + .../reverseproxy/service/interface_mock.go | 29 ++ .../reverseproxy/service/manager/api.go | 24 + .../reverseproxy/service/manager/manager.go | 20 +- .../service/manager/manager_test.go | 6 +- management/internals/server/boot.go | 2 +- management/internals/shared/grpc/proxy.go | 192 ++++++-- .../shared/grpc/proxy_address_test.go | 29 ++ .../internals/shared/grpc/proxy_auth.go | 3 - .../shared/grpc/proxy_group_access_test.go | 18 + .../internals/shared/grpc/proxy_test.go | 55 +++ .../shared/grpc/validate_session_test.go | 28 +- management/server/account_test.go | 2 +- management/server/http/handler.go | 4 + .../proxy/auth_callback_integration_test.go | 9 + .../testing/testing_tools/channel/channel.go | 4 +- management/server/store/sql_store.go | 121 +++++- management/server/store/store.go | 13 +- management/server/store/store_mock.go | 157 ++++++- proxy/management_byop_integration_test.go | 409 ++++++++++++++++++ proxy/management_integration_test.go | 34 +- shared/management/http/api/openapi.yml | 165 +++++++ shared/management/http/api/types.gen.go | 41 ++ 32 files changed, 2352 insertions(+), 117 deletions(-) create mode 100644 management/internals/modules/reverseproxy/domain/manager/manager_test.go create mode 100644 
management/internals/modules/reverseproxy/proxy/manager/manager_test.go create mode 100644 management/internals/modules/reverseproxy/proxytoken/handler.go create mode 100644 management/internals/modules/reverseproxy/proxytoken/handler_test.go create mode 100644 management/internals/shared/grpc/proxy_address_test.go create mode 100644 proxy/management_byop_integration_test.go diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 2c4c1372e..ab899e0bf 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,6 +31,7 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool @@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d var ret []*domain.Domain // Add connected proxy clusters as free domains. - // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // For BYOP accounts, only their own cluster is returned; otherwise shared clusters. 
+ allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) return nil, err @@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName return nil, status.NewPermissionDeniedError() } - // Verify the target cluster is in the available clusters - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // Verify the target cluster is in the available clusters for this account + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string { // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster. func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } +func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) { + byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get BYOP cluster addresses: %w", err) + } + if len(byopAddresses) > 0 { + return byopAddresses, nil + } + return m.proxyManager.GetActiveClusterAddresses(ctx) +} + func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { bestCluster := 
"" bestLen := -1 diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go new file mode 100644 index 000000000..fdeb0765f --- /dev/null +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -0,0 +1,110 @@ +package manager + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockProxyManager struct { + getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error) +} + +func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveClusterAddressesFunc != nil { + return m.getActiveClusterAddressesFunc(ctx) + } + return nil, nil +} + +func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveClusterAddressesForAccountFunc != nil { + return m.getActiveClusterAddressesForAccountFunc(ctx, accountID) + } + return nil, nil +} + +func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func TestGetClusterAllowList_BYOPProxy(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist") + return nil, nil + }, + } + + mgr := 
Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byop.example.com"}, result) +} + +func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, errors.New("db error") + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "BYOP cluster addresses") +} + +func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) +} + diff --git 
a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 53c52b3aa..07ea6f0ab 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,15 +11,19 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) + Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) Disconnect(ctx context.Context, proxyID, sessionID string) error Heartbeat(ctx context.Context, p *Proxy) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) - GetActiveClusters(ctx context.Context) ([]Cluster, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error + GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) + CountAccountProxies(ctx context.Context, accountID string) (int64, error) + IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error } // OIDCValidationConfig contains the OIDC configuration needed for token validation. 
diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index 341e8c943..b72a6ebe5 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -16,11 +16,16 @@ type store interface { DisconnectProxy(ctx context.Context, proxyID, sessionID string) error UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error } // Manager handles all proxy operations @@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { // Connect registers a new proxy connection in the database. // capabilities may be nil for old proxies that do not report them. 
-func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { +func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { now := time.Now() var caps proxy.Capabilities if capabilities != nil { @@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress SessionID: sessionID, ClusterAddress: clusterAddress, IPAddress: ipAddress, + AccountID: accountID, LastSeen: now, ConnectedAt: &now, - Status: "connected", + Status: proxy.StatusConnected, Capabilities: caps, } @@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress } // Disconnect marks a proxy as disconnected in the database. -func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { +func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil { log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err) return err @@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro } // Heartbeat updates the proxy's last seen timestamp. 
-func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { +func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil { log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err) return err @@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { } // GetActiveClusterAddresses returns all unique cluster addresses for active proxies -func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { +func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) @@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error return addresses, nil } -// GetActiveClusters returns all active proxy clusters with their connected proxy count. -func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { - clusters, err := m.store.GetActiveProxyClusters(ctx) - if err != nil { - log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) - return nil, err - } - return clusters, nil -} - // ClusterSupportsCustomPorts returns whether any active proxy in the cluster // supports custom ports. Returns nil when no proxy has reported capabilities. 
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { @@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string } // CleanupStale removes proxies that haven't sent heartbeat in the specified duration -func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { +func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) return err } return nil } + +func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err) + return nil, err + } + return addresses, nil +} + +func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) { + return m.store.GetProxyByAccountID(ctx, accountID) +} + +func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + return m.store.CountProxiesByAccountID(ctx, accountID) +} + +func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID) + if err != nil { + return false, err + } + return !conflicting, nil +} + +func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil { + log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err) + return err + } + return nil +} + diff --git 
a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go new file mode 100644 index 000000000..3c53fe684 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -0,0 +1,337 @@ +package manager + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +type mockStore struct { + saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error + disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error + updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error + getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) + cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error + getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) + countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) + isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) + deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error +} + +func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { + if m.saveProxyFunc != nil { + return m.saveProxyFunc(ctx, p) + } + return nil +} +func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + if m.disconnectProxyFunc != nil { + return m.disconnectProxyFunc(ctx, proxyID, sessionID) + } + return nil +} +func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { + if m.updateProxyHeartbeatFunc != nil { + return 
m.updateProxyHeartbeatFunc(ctx, p) + } + return nil +} +func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveProxyClusterAddressesFunc != nil { + return m.getActiveProxyClusterAddressesFunc(ctx) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveProxyClusterAddressesForAccFunc != nil { + return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) { + return nil, nil +} +func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error { + if m.cleanupStaleProxiesFunc != nil { + return m.cleanupStaleProxiesFunc(ctx, d) + } + return nil +} +func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + if m.getProxyByAccountIDFunc != nil { + return m.getProxyByAccountIDFunc(ctx, accountID) + } + return nil, fmt.Errorf("proxy not found for account %s", accountID) +} +func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + if m.countProxiesByAccountIDFunc != nil { + return m.countProxiesByAccountIDFunc(ctx, accountID) + } + return 0, nil +} +func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + if m.isClusterAddressConflictingFunc != nil { + return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID) + } + return false, nil +} +func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + if m.deleteAccountClusterFunc != nil { + return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID) + } + return nil +} +func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} +func (m *mockStore) 
GetClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} +func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func newTestManager(s store) *Manager { + meter := noop.NewMeterProvider().Meter("test") + m, err := NewManager(s, meter) + if err != nil { + panic(err) + } + return m +} + +func TestConnect_WithAccountID(t *testing.T) { + accountID := "acc-123" + + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + _, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Equal(t, "proxy-1", savedProxy.ID) + assert.Equal(t, "session-1", savedProxy.SessionID) + assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress) + assert.Equal(t, "10.0.0.1", savedProxy.IPAddress) + assert.Equal(t, &accountID, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) + assert.NotNil(t, savedProxy.ConnectedAt) +} + +func TestConnect_WithoutAccountID(t *testing.T) { + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + _, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Nil(t, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) +} + +func TestConnect_StoreError(t *testing.T) { + s := &mockStore{ + saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + _, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil) + 
assert.Error(t, err) +} + +func TestIsClusterAddressAvailable(t *testing.T) { + tests := []struct { + name string + conflicting bool + storeErr error + wantResult bool + wantErr bool + }{ + { + name: "available - no conflict", + conflicting: false, + wantResult: true, + }, + { + name: "not available - conflict exists", + conflicting: true, + wantResult: false, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) { + return tt.conflicting, tt.storeErr + }, + } + + mgr := newTestManager(s) + result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +func TestCountAccountProxies(t *testing.T) { + tests := []struct { + name string + count int64 + storeErr error + wantCount int64 + wantErr bool + }{ + { + name: "no proxies", + count: 0, + wantCount: 0, + }, + { + name: "one proxy", + count: 1, + wantCount: 1, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) { + return tt.count, tt.storeErr + }, + } + + mgr := newTestManager(s) + count, err := mgr.CountAccountProxies(context.Background(), "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + }) + } +} + +func TestGetAccountProxy(t *testing.T) { + accountID := "acc-123" + + t.Run("found", func(t *testing.T) { + expected := &proxy.Proxy{ + ID: "proxy-1", + ClusterAddress: "byop.example.com", + AccountID: &accountID, + Status: proxy.StatusConnected, + } 
+ s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) { + assert.Equal(t, accountID, accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + p, err := mgr.GetAccountProxy(context.Background(), accountID) + require.NoError(t, err) + assert.Equal(t, expected, p) + }) + + t.Run("not found", func(t *testing.T) { + s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, errors.New("not found") + }, + } + + mgr := newTestManager(s) + _, err := mgr.GetAccountProxy(context.Background(), accountID) + assert.Error(t, err) + }) +} + +func TestDeleteAccountCluster(t *testing.T) { + t.Run("success", func(t *testing.T) { + var deletedCluster, deletedAccount string + s := &mockStore{ + deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error { + deletedCluster = clusterAddress + deletedAccount = accountID + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123") + require.NoError(t, err) + assert.Equal(t, "cluster.example.com", deletedCluster) + assert.Equal(t, "acc-123", deletedAccount) + }) + + t.Run("store error", func(t *testing.T) { + s := &mockStore{ + deleteAccountClusterFunc: func(_ context.Context, _, _ string) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123") + assert.Error(t, err) + }) +} + +func TestGetActiveClusterAddressesForAccount(t *testing.T) { + expected := []string{"byop.example.com"} + s := &mockStore{ + getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123") + require.NoError(t, 
err) + assert.Equal(t, expected, result) +} diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index 98d97b3c6..a0e360a1b 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte } // Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) { +func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities) ret0, _ := ret[0].(*Proxy) ret1, _ := ret[1].(error) return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities) } // Disconnect mocks base method. 
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) } -// GetActiveClusters mocks base method. -func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { +func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveClusters", ctx) - ret0, _ := ret[0].([]Cluster) + ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetActiveClusters indicates an expected call of GetActiveClusters. -func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID) } // Heartbeat mocks base method. @@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p) } +// GetAccountProxy mocks base method. +func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountProxy indicates an expected call of GetAccountProxy. 
+func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID) +} + +// CountAccountProxies mocks base method. +func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAccountProxies indicates an expected call of CountAccountProxies. +func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID) +} + +// IsClusterAddressAvailable mocks base method. +func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable. +func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID) +} + +// DeleteAccountCluster mocks base method. 
+func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // MockController is a mock of Controller interface. type MockController struct { ctrl *gomock.Controller diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index dcedb8811..64394799e 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -1,6 +1,13 @@ package proxy -import "time" +import ( + "time" +) + +const ( + StatusConnected = "connected" + StatusDisconnected = "disconnected" +) // Capabilities describes what a proxy can handle, as reported via gRPC. // Nil fields mean the proxy never reported this capability. @@ -21,6 +28,7 @@ type Proxy struct { SessionID string `gorm:"type:varchar(36)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` + AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time @@ -36,6 +44,8 @@ func (Proxy) TableName() string { // Cluster represents a group of proxy nodes serving the same address. 
type Cluster struct { + ID string Address string ConnectedProxies int + SelfHosted bool } diff --git a/management/internals/modules/reverseproxy/proxytoken/handler.go b/management/internals/modules/reverseproxy/proxytoken/handler.go new file mode 100644 index 000000000..728cdf723 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler.go @@ -0,0 +1,195 @@ +package proxytoken + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +type handler struct { + store store.Store + permissionsManager permissions.Manager +} + +func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) { + h := &handler{store: s, permissionsManager: permissionsManager} + router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS") +} + +func (h *handler) createToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create) + if err != nil { + 
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + var req api.ProxyTokenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if req.Name == "" || len(req.Name) > 255 { + util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w) + return + } + + var expiresIn time.Duration + if req.ExpiresIn != nil { + if *req.ExpiresIn < 0 { + util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w) + return + } + if *req.ExpiresIn > 0 { + expiresIn = time.Duration(*req.ExpiresIn) * time.Second + } + } + + accountID := userAuth.AccountId + generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId) + if err != nil { + util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w) + return + } + + if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil { + util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w) + return + } + + resp := toProxyTokenCreatedResponse(generated) + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokens, err := 
h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId) + if err != nil { + util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w) + return + } + + resp := make([]api.ProxyToken, 0, len(tokens)) + for _, token := range tokens { + resp = append(resp, toProxyTokenResponse(token)) + } + + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokenID := mux.Vars(r)["tokenId"] + if tokenID == "" { + util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w) + return + } + + token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID) + if err != nil { + if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + } else { + util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w) + } + return + } + + if token.AccountID == nil || *token.AccountID != userAuth.AccountId { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + return + } + + if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil { + util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken { + resp := 
api.ProxyToken{ + Id: token.ID, + Name: token.Name, + Revoked: token.Revoked, + } + if !token.CreatedAt.IsZero() { + resp.CreatedAt = token.CreatedAt + } + if token.ExpiresAt != nil { + resp.ExpiresAt = token.ExpiresAt + } + if token.LastUsed != nil { + resp.LastUsed = token.LastUsed + } + return resp +} + +func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated { + base := toProxyTokenResponse(&generated.ProxyAccessToken) + plainToken := string(generated.PlainToken) + return api.ProxyTokenCreated{ + Id: base.Id, + Name: base.Name, + CreatedAt: base.CreatedAt, + ExpiresAt: base.ExpiresAt, + LastUsed: base.LastUsed, + Revoked: base.Revoked, + PlainToken: plainToken, + } +} diff --git a/management/internals/modules/reverseproxy/proxytoken/handler_test.go b/management/internals/modules/reverseproxy/proxytoken/handler_test.go new file mode 100644 index 000000000..a28752909 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler_test.go @@ -0,0 +1,275 @@ +package proxytoken + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func authContext(accountID, userID string) context.Context { + return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{ + AccountId: accountID, + UserId: userID, + }) +} + 
+func TestCreateToken_AccountScoped(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body := `{"name": "my-token"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp api.ProxyTokenCreated + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + assert.NotEmpty(t, resp.PlainToken) + assert.Equal(t, "my-token", resp.Name) + assert.False(t, resp.Revoked) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.AccountID) + assert.Equal(t, accountID, *savedToken.AccountID) + assert.Equal(t, "user-1", savedToken.CreatedBy) +} + +func TestCreateToken_WithExpiration(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body 
:= `{"name": "expiring-token", "expires_in": 3600}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.ExpiresAt) + assert.True(t, savedToken.ExpiresAt.After(time.Now())) +} + +func TestCreateToken_EmptyName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": ""}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCreateToken_PermissionDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": "test"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestListTokens(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + now := time.Now() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), 
store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{ + {ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false}, + {ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true}, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.listTokens(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp []api.ProxyToken + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Len(t, resp, 2) + assert.Equal(t, "tok-1", resp[0].Id) + assert.False(t, resp[0].Revoked) + assert.Equal(t, "tok-2", resp[1].Id) + assert.True(t, resp[1].Revoked) +} + +func TestRevokeToken_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + Name: "test-token", + AccountID: &accountID, + }, nil) + mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext(accountID, "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + 
assert.Equal(t, http.StatusOK, w.Code) +} + +func TestRevokeToken_WrongAccount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + otherAccount := "acc-other" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: &otherAccount, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestRevokeToken_ManagementWideToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: nil, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} diff --git a/management/internals/modules/reverseproxy/service/interface.go 
b/management/internals/modules/reverseproxy/service/interface.go index a49cbea35..6a94aa32b 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -10,6 +10,7 @@ import ( type Manager interface { GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) + DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) @@ -28,4 +29,5 @@ type Manager interface { RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StartExposeReaper(ctx context.Context) + GetServiceByDomain(ctx context.Context, domain string) (*Service, error) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index cc5ccbb8e..83b2162ed 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) } +// DeleteAccountCluster mocks base method. 
+func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() @@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID) } +// GetServiceByDomain mocks base method. +func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) + ret0, _ := ret[0].(*Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceByDomain indicates an expected call of GetServiceByDomain. +func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain) +} + // GetGlobalServices mocks base method. 
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index cd81efa88..08272077c 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma accesslogsmanager.RegisterEndpoints(router, accessLogsManager) router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") @@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { apiClusters := make([]api.ProxyCluster, 0, len(clusters)) for _, c := range clusters { apiClusters = append(apiClusters, api.ProxyCluster{ + Id: c.ID, Address: c.Address, ConnectedProxies: c.ConnectedProxies, + SelfHosted: c.SelfHosted, }) } util.WriteJSONObject(r.Context(), w, apiClusters) } + +func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusterAddress := mux.Vars(r)["clusterAddress"] + if clusterAddress == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w) + return + } + + if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil { + 
util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index d03a8dc82..c866d8f75 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -122,7 +122,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin return nil, status.NewPermissionDeniedError() } - return m.store.GetActiveProxyClusters(ctx) + return m.store.GetActiveProxyClusters(ctx, accountID) +} + +// DeleteAccountCluster removes all proxy registrations for the given cluster address +// owned by the account. +func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID) } func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { @@ -986,6 +1000,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]* return services, nil } +func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go 
b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 46e79f1e5..47b8b3865 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { t.Helper() tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) return srv } @@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) @@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 
f2ab0a2c4..7c655f020 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store()) s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 6763a3ba3..9e5027547 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net" "net/http" "net/url" "os" @@ -50,6 +51,11 @@ type ProxyOIDCConfig struct { KeysLocation string } +// ProxyTokenChecker checks whether a proxy access token is still valid. 
+type ProxyTokenChecker interface { + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) +} + // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -78,6 +84,9 @@ type ProxyServiceServer struct { // Store for one-time authentication tokens tokenStore *OneTimeTokenStore + // Checker for proxy access token validity + tokenChecker ProxyTokenChecker + // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig @@ -123,6 +132,8 @@ type proxyConnection struct { proxyID string sessionID string address string + accountID *string + tokenID string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer sendChan chan *proto.GetMappingUpdateResponse @@ -130,8 +141,19 @@ type proxyConnection struct { cancel context.CancelFunc } +func enforceAccountScope(ctx context.Context, requestAccountID string) error { + token := GetProxyTokenFromContext(ctx) + if token == nil || token.AccountID == nil { + return nil + } + if requestAccountID == "" || *token.AccountID != requestAccountID { + return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID) + } + return nil +} + // NewProxyServiceServer creates a new proxy service server. 
-func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, @@ -141,6 +163,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + tokenChecker: tokenChecker, snapshotBatchSize: snapshotBatchSizeFromEnv(), cancel: cancel, } @@ -200,6 +223,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + var accountID *string + token := GetProxyTokenFromContext(ctx) + if token != nil && token.AccountID != nil { + accountID = token.AccountID + + available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID) + if err != nil { + return status.Errorf(codes.Internal, "check cluster address: %v", err) + } + if !available { + return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress) + } + } + + var tokenID string + if token != nil { + tokenID = token.ID + } + sessionID := uuid.NewString() if old, loaded := s.connectedProxies.Load(proxyID); loaded { @@ -217,6 +259,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyID: proxyID, sessionID: sessionID, address: proxyAddress, + accountID: accountID, + tokenID: tokenID, capabilities: req.GetCapabilities(), stream: stream, sendChan: make(chan 
*proto.GetMappingUpdateResponse, 100), @@ -224,7 +268,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { caps = &proxy.Capabilities{ @@ -233,10 +276,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest SupportsCrowdsec: c.SupportsCrowdsec, } } - proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) + proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps) if err != nil { - log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) cancel() + if accountID != nil { + return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) + } + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) return status.Errorf(codes.Internal, "register proxy in database: %v", err) } @@ -266,6 +312,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest "session_id": sessionID, "address": proxyAddress, "cluster_addr": proxyAddress, + "account_id": accountID, "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { @@ -286,7 +333,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() - go s.heartbeat(connCtx, proxyRecord) + go s.heartbeat(connCtx, conn, proxyRecord) select { case err := <-errChan: @@ -298,8 +345,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { +// heartbeat updates the proxy's last_seen timestamp every minute and +// disconnects the proxy 
if its access token has been revoked. +func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() @@ -309,6 +357,19 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { if err := s.proxyManager.Heartbeat(ctx, p); err != nil { log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err) } + + if conn.tokenID != "" && s.tokenChecker != nil { + valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID) + if err != nil { + log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err) + continue + } + if !valid { + log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID) + conn.cancel() + return + } + } case <-ctx.Done(): log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID) return @@ -316,8 +377,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { } } -// sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only entries matching the proxy's cluster address are sent. 
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") @@ -355,7 +414,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec } func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) + var services []*rpservice.Service + var err error + if conn.accountID != nil { + services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID) + } else { + services, err = s.serviceManager.GetGlobalServices(ctx) + } if err != nil { return nil, fmt.Errorf("get services from store: %w", err) } @@ -380,8 +445,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return mappings, nil } -// isProxyAddressValid validates a proxy address +// isProxyAddressValid validates a proxy address (domain name or IP address) func isProxyAddressValid(addr string) bool { + if addr == "" { + return false + } + if net.ParseIP(addr) != nil { + return true + } _, err := domain.ValidateDomains([]string{addr}) return err == nil } @@ -405,6 +476,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() + if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil { + return nil, err + } + fields := log.Fields{ "service_id": accessLog.GetServiceId(), "account_id": accessLog.GetAccountId(), @@ -442,11 +517,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. 
// For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. +// BYOP proxies only receive updates for their own account's services. func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") + updateAccountIDs := make(map[string]struct{}) + for _, m := range update.Mapping { + if m.AccountId != "" { + updateAccountIDs[m.AccountId] = struct{}{} + } + } s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - resp := s.perProxyMessage(update, conn.proxyID) + connUpdate := update + if conn.accountID != nil && len(updateAccountIDs) > 0 { + if _, ok := updateAccountIDs[*conn.accountID]; !ok { + return true + } + filtered := filterMappingsForAccount(update.Mapping, *conn.accountID) + if len(filtered) == 0 { + return true + } + connUpdate = &proto.GetMappingUpdateResponse{ + Mapping: filtered, + InitialSyncComplete: update.InitialSyncComplete, + } + } + resp := s.perProxyMessage(connUpdate, conn.proxyID) if resp == nil { log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) conn.cancel() @@ -463,6 +559,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes }) } +// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect. 
+func (s *ProxyServiceServer) ForceDisconnect(proxyID string) { + if connVal, ok := s.connectedProxies.Load(proxyID); ok { + conn := connVal.(*proxyConnection) + conn.cancel() + s.connectedProxies.Delete(proxyID) + log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy") + } +} + +func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping { + var filtered []*proto.ProxyMapping + for _, m := range mappings { + if m.AccountId == accountID { + filtered = append(filtered, m) + } + } + return filtered +} + // GetConnectedProxies returns a list of connected proxy IDs func (s *ProxyServiceServer) GetConnectedProxies() []string { var proxies []string @@ -531,6 +647,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd continue } conn := connVal.(*proxyConnection) + if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId { + continue + } if !proxyAcceptsMapping(conn, update) { log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) continue @@ -618,6 +737,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) @@ -737,6 +860,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic // SendStatusUpdate handles status updates from proxy clients. 
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + accountID := req.GetAccountId() serviceID := req.GetServiceId() protoStatus := req.GetStatus() @@ -807,6 +934,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { // CreateProxyPeer handles proxy peer creation with one-time token authentication func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + serviceID := req.GetServiceId() accountID := req.GetAccountId() token := req.GetToken() @@ -861,6 +992,10 @@ func strPtr(s string) *string { } func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + redirectURL, err := url.Parse(req.GetRedirectUrl()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) @@ -989,21 +1124,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL // GenerateSessionToken creates a signed session JWT for the given domain and user. 
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { - // Find the service by domain to get its signing key - services, err := s.serviceManager.GetGlobalServices(ctx) + service, err := s.getServiceByDomain(ctx, domain) if err != nil { - return "", fmt.Errorf("get services: %w", err) - } - - var service *rpservice.Service - for _, svc := range services { - if svc.Domain == domain { - service = svc - break - } - } - if service == nil { - return "", fmt.Errorf("service not found for domain: %s", domain) + return "", fmt.Errorf("service not found for domain %s: %w", domain, err) } if service.SessionPrivateKey == "" { @@ -1101,6 +1224,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } + if err := enforceAccountScope(ctx, service.AccountID); err != nil { + return nil, err + } + pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey) if err != nil { log.WithFields(log.Fields{ @@ -1184,18 +1311,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val } func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return nil, fmt.Errorf("get services: %w", err) - } - - for _, service := range services { - if service.Domain == domain { - return service, nil - } - } - - return nil, fmt.Errorf("service not found for domain: %s", domain) + return s.serviceManager.GetServiceByDomain(ctx, domain) } func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { diff --git a/management/internals/shared/grpc/proxy_address_test.go b/management/internals/shared/grpc/proxy_address_test.go new file mode 100644 index 000000000..824a57226 --- /dev/null +++ b/management/internals/shared/grpc/proxy_address_test.go @@ -0,0 +1,29 @@ +package grpc + 
+import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsProxyAddressValid(t *testing.T) { + tests := []struct { + name string + addr string + valid bool + }{ + {name: "valid domain", addr: "eu.proxy.netbird.io", valid: true}, + {name: "valid subdomain", addr: "byop.proxy.example.com", valid: true}, + {name: "valid IPv4", addr: "10.0.0.1", valid: true}, + {name: "valid IPv4 public", addr: "203.0.113.10", valid: true}, + {name: "valid IPv6", addr: "::1", valid: true}, + {name: "valid IPv6 full", addr: "2001:db8::1", valid: true}, + {name: "empty string", addr: "", valid: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr)) + }) + } +} diff --git a/management/internals/shared/grpc/proxy_auth.go b/management/internals/shared/grpc/proxy_auth.go index dd593dfa0..9888e8eee 100644 --- a/management/internals/shared/grpc/proxy_auth.go +++ b/management/internals/shared/grpc/proxy_auth.go @@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types return nil, status.Errorf(codes.Unauthenticated, "invalid token") } - // TODO: Enforce AccountID scope for "bring your own proxy" feature. - // Currently tokens are management-wide; AccountID field is reserved for future use. 
- if !token.IsValid() { return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked") } diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 0fa9a0dc1..46dad5b56 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, return nil } +func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error { return nil } @@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} +func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) { + if m.err != nil { + return nil, m.err + } + for _, services := range m.proxiesByAccount { + for _, svc := range services { + if svc.Domain == domain { + return svc, nil + } + } + } + return nil, errors.New("service not found for domain: " + domain) +} + func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 5a7a457df..0379edc6d 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -12,9 +12,12 @@ import ( cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" 
nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -316,6 +319,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { assert.Contains(t, err.Error(), "invalid state format") } +func scopedCtx(accountID string) context.Context { + token := &types.ProxyAccessToken{ + ID: "token-1", + AccountID: &accountID, + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func globalCtx() context.Context { + token := &types.ProxyAccessToken{ + ID: "token-global", + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-1") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-2") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) { + err := enforceAccountScope(globalCtx(), "acc-1") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "acc-2") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) { + err := enforceAccountScope(context.Background(), "acc-1") + assert.NoError(t, err) +} + func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() pkceStore := NewPKCEVerifierStore(ctx, 
testCacheStore(t)) diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index d1d7fc8b7..6cd95f988 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil) proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) @@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error { return nil } @@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co return nil, nil } +func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testValidateSessionProxyManager) GetActiveClusters(_ 
context.Context) ([]proxy.Cluster, error) { return nil, nil } @@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time return nil } +func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error { + return nil +} + func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { return nil } diff --git a/management/server/account_test.go b/management/server/account_test.go index 6bb875f99..65b27df49 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3113,7 +3113,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) if err != nil { return nil, nil, err diff --git a/management/server/http/handler.go b/management/server/http/handler.go index b9ea605d3..1e2c710db 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" 
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -144,6 +145,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } + + proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router) + // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index c99acab63..30d8aa0e7 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -216,6 +216,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { nil, usersManager, nil, + nil, ) proxyService.SetServiceManager(&testServiceManager{store: testStore}) @@ -389,6 +390,10 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er return nil } +func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { return nil } @@ -435,6 +440,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri func (m *testServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *testServiceManager) GetActiveClusters(_ 
context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 1a8b83c7e..3c4ea98d0 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { @@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 065a0d306..4c2f0be52 100644 --- a/management/server/store/sql_store.go +++ 
b/management/server/store/sql_store.go @@ -4513,6 +4513,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e return nil } +func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var tokens []*types.ProxyAccessToken + result := tx.Where("account_id = ?", accountID).Find(&tokens) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error) + } + + return tokens, nil +} + +func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID) + if err != nil { + return false, err + } + return token.IsValid(), nil +} + +func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var token types.ProxyAccessToken + result := tx.Take(&token, idQueryCondition, tokenID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy access token not found") + } + return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error) + } + + return &token, nil +} + // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { result := s.db.Model(&types.ProxyAccessToken{}). @@ -5487,7 +5528,7 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin Model(&proxy.Proxy{}). Where("id = ? 
AND session_id = ?", proxyID, sessionID). Updates(map[string]any{ - "status": "disconnected", + "status": proxy.StatusDisconnected, "disconnected_at": now, "last_seen": now, }) @@ -5518,7 +5559,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err if result.RowsAffected == 0 { p.LastSeen = now p.ConnectedAt = &now - p.Status = "connected" + p.Status = proxy.StatusConnected if err := s.db.Create(p).Error; err != nil { log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err) } @@ -5527,13 +5568,15 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err return nil } -// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies +// GetActiveProxyClusterAddresses returns the unique cluster addresses of active +// shared proxies (those without an account scope). BYOP cluster addresses are +// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them. func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { var addresses []string result := s.db. Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5545,13 +5588,75 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string return addresses, nil } -// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. -func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + var addresses []string + + result := s.db. + Model(&proxy.Proxy{}). + Where("account_id = ? AND status = ? 
AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). + Distinct("cluster_address"). + Pluck("cluster_address", &addresses) + + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account") + } + + return addresses, nil +} + +func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + var p proxy.Proxy + result := s.db.Where("account_id = ?", accountID).Take(&p) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy not found for account") + } + return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error) + } + return &p, nil +} + +func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + var count int64 + result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error) + } + return count, nil +} + +func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + var count int64 + result := s.db. + Model(&proxy.Proxy{}). + Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID). + Count(&count) + if result.Error != nil { + return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error) + } + return count > 0, nil +} + +func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + result := s.db. + Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID). 
+ Delete(&proxy.Proxy{}) + if result.Error != nil { + return status.Errorf(status.Internal, "delete account cluster: %v", result.Error) + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "cluster not found") + } + return nil +} + +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { var clusters []proxy.Cluster result := s.db.Model(&proxy.Proxy{}). - Select("cluster_address as address, COUNT(*) as connected_proxies"). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted"). + Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)", + proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID). Group("cluster_address"). Scan(&clusters) diff --git a/management/server/store/store.go b/management/server/store/store.go index db98bc644..aa601c33f 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -114,6 +114,9 @@ type Store interface { GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error RevokeProxyAccessToken(ctx context.Context, tokenID string) error MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error @@ -288,11 +291,16 @@ type 
Store interface { DisconnectProxy(ctx context.Context, proxyID, sessionID string) error UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) @@ -496,6 +504,9 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key") }, + func(db *gorm.DB) error { + return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique") + }, } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 6c2c9bbc3..9780c521e 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -166,20 +166,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, 
inactivityDuration) } -// GetClusterSupportsCrowdSec mocks base method. -func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. -func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) -} - // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -238,6 +224,21 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID) } +// CountProxiesByAccountID mocks base method. +func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID. +func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID) +} + // CreateAccessLog mocks base method. 
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { m.ctrl.T.Helper() @@ -576,6 +577,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID) } +// DeleteAccountCluster mocks base method. +func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // DeleteRoute mocks base method. func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error { m.ctrl.T.Helper() @@ -1302,19 +1317,34 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) } -// GetActiveProxyClusters mocks base method. -func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +// GetActiveProxyClusterAddressesForAccount mocks base method. 
+func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount. +func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID) +} + +// GetActiveProxyClusters mocks base method. +func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID) ret0, _ := ret[0].([]proxy.Cluster) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. -func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID) } // GetAllAccounts mocks base method. 
@@ -1390,6 +1420,20 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) } +// GetClusterSupportsCrowdSec mocks base method. +func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. +func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) +} + // GetClusterSupportsCustomPorts mocks base method. func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { m.ctrl.T.Helper() @@ -1959,6 +2003,51 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken) } +// GetProxyAccessTokenByID mocks base method. +func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID) + ret0, _ := ret[0].(*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID. 
+func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID) +} + +// GetProxyAccessTokensByAccountID mocks base method. +func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID. +func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID) +} + +// GetProxyByAccountID mocks base method. +func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID) + ret0, _ := ret[0].(*proxy.Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyByAccountID indicates an expected call of GetProxyByAccountID. +func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID) +} + // GetResourceGroups mocks base method. 
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { m.ctrl.T.Helper() @@ -2391,6 +2480,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) } +// IsClusterAddressConflicting mocks base method. +func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting. +func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID) +} + // IsPrimaryAccount mocks base method. func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { m.ctrl.T.Helper() @@ -2407,6 +2511,21 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID) } +// IsProxyAccessTokenValid mocks base method. 
+func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid. +func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID) +} + // ListCustomDomains mocks base method. func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) { m.ctrl.T.Helper() diff --git a/proxy/management_byop_integration_test.go b/proxy/management_byop_integration_test.go new file mode 100644 index 000000000..c0fbe682a --- /dev/null +++ b/proxy/management_byop_integration_test.go @@ -0,0 +1,409 @@ +package proxy + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + grpcstatus "google.golang.org/grpc/status" + + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" + 
"github.com/netbirdio/netbird/shared/management/proto" +) + +type byopTestSetup struct { + store store.Store + proxyService *nbgrpc.ProxyServiceServer + grpcServer *grpc.Server + grpcAddr string + cleanup func() + + accountA string + accountB string + accountAToken types.PlainProxyToken + accountBToken types.PlainProxyToken + accountACluster string + accountBCluster string +} + +func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup { + t.Helper() + ctx := context.Background() + + testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + + accountAID := "byop-account-a" + accountBID := "byop-account-b" + + for _, acc := range []*types.Account{ + {Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + {Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + } { + require.NoError(t, testStore.SaveAccount(ctx, acc)) + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pubKey := base64.StdEncoding.EncodeToString(pub) + privKey := base64.StdEncoding.EncodeToString(priv) + + clusterA := "byop-a.proxy.test" + clusterB := "byop-b.proxy.test" + + services := []*service.Service{ + { + ID: "svc-a1", AccountID: accountAID, Name: "App A1", + Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-a2", AccountID: accountAID, Name: "App A2", + Domain: "app2." 
+ clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-b1", AccountID: accountBID, Name: "App B1", + Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}}, + }, + } + for _, svc := range services { + require.NoError(t, testStore.CreateService(ctx, svc)) + } + + tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken)) + + tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken)) + + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) + + meter := noop.NewMeterProvider().Meter("test") + realProxyManager, err := proxymanager.NewManager(testStore, meter) + require.NoError(t, err) + + oidcConfig := nbgrpc.ProxyOIDCConfig{ + Issuer: "https://fake-issuer.example.com", + ClientID: "test-client", + HMACKey: []byte("test-hmac-key"), + } + + usersManager := users.NewManager(testStore) + + proxyService := nbgrpc.NewProxyServiceServer( + &testAccessLogManager{}, + tokenStore, + pkceStore, + oidcConfig, + nil, + usersManager, + realProxyManager, + nil, + ) + + svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} + proxyService.SetServiceManager(svcMgr) + + 
proxyController := &testProxyController{} + proxyService.SetProxyController(proxyController) + + _, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor)) + proto.RegisterProxyServiceServer(grpcServer, proxyService) + + go func() { + if err := grpcServer.Serve(lis); err != nil { + t.Logf("gRPC server error: %v", err) + } + }() + + return &byopTestSetup{ + store: testStore, + proxyService: proxyService, + grpcServer: grpcServer, + grpcAddr: lis.Addr().String(), + cleanup: func() { + grpcServer.GracefulStop() + authClose() + storeCleanup() + }, + accountA: accountAID, + accountB: accountBID, + accountAToken: tokenA.PlainToken, + accountBToken: tokenB.PlainToken, + accountACluster: clusterA, + accountBCluster: clusterB, + } +} + +func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context { + md := metadata.Pairs("authorization", "Bearer "+string(token)) + return metadata.NewOutgoingContext(ctx, md) +} + +func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { + t.Helper() + var mappings []*proto.ProxyMapping + for { + msg, err := stream.Recv() + require.NoError(t, err) + mappings = append(mappings, msg.GetMapping()...) 
+ if msg.GetInitialSyncComplete() { + break + } + } + return mappings +} + +func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings := receiveBYOPMappings(t, stream) + + assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services") + for _, m := range mappings { + assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A") + t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId()) + } + + ids := map[string]bool{} + for _, m := range mappings { + ids[m.GetId()] = true + } + assert.True(t, ids["svc-a1"], "should contain svc-a1") + assert.True(t, ids["svc-a2"], "should contain svc-a2") + assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1") +} + +func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-b", + Version: "test-v1", + Address: 
setup.accountBCluster, + }) + require.NoError(t, err) + + mappings := receiveBYOPMappings(t, stream) + + assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service") + assert.Equal(t, "svc-b1", mappings[0].GetId()) + assert.Equal(t, setup.accountB, mappings[0].GetAccountId()) +} + +func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-first", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings1 := receiveBYOPMappings(t, stream1) + assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services") + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-second", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings2 := receiveBYOPMappings(t, stream2) + assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services") + for _, m := range mappings2 { + assert.Equal(t, setup.accountA, m.GetAccountId()) + } +} + +func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := 
proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-cluster", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _ = receiveBYOPMappings(t, stream1) + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-b-conflict", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _, err = stream2.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists") + t.Logf("expected rejection: %s", st.Message()) +} + +func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + proxyID := "byop-proxy-reconnect" + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + firstMappings := receiveBYOPMappings(t, stream1) + cancel1() + + time.Sleep(200 * time.Millisecond) + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, 
&proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + secondMappings := receiveBYOPMappings(t, stream2) + + assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings") + + firstIDs := map[string]bool{} + for _, m := range firstMappings { + firstIDs[m.GetId()] = true + } + for _, m := range secondMappings { + assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId()) + } +} + +func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "no-auth-proxy", + Version: "test-v1", + Address: "some.cluster.io", + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unauthenticated, st.Code()) +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 99bbdad0c..9fd3d2ce9 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -140,6 +141,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { nil, usersManager, proxyManager, + nil, ) // Use store-backed service manager @@ -201,8 +203,8 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. 
type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { - return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil +func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { + return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: nbproxy.StatusConnected}, nil } func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error { @@ -217,6 +219,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin return nil, nil } +func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { return nil, nil } @@ -237,6 +243,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro return nil } +func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) { + return nil, fmt.Errorf("proxy not found for account %s", accountID) +} + +func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error { + return nil +} + // testProxyController is a mock implementation of rpservice.ProxyController for testing. 
type testProxyController struct{} @@ -290,6 +312,10 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID return nil } +func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { return nil } @@ -336,6 +362,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} +func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 82fca0782..942f3aa45 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3355,10 +3355,64 @@ components: example: false required: - enabled + ProxyTokenRequest: + type: object + properties: + name: + type: string + description: Human-readable token name + example: "my-proxy-token" + expires_in: + type: integer + minimum: 0 + description: Token expiration in seconds (0 = never expires) + example: 0 + required: + - name + ProxyToken: + type: object + properties: + id: + type: string + name: + type: string + expires_at: + type: string + format: date-time + created_at: + type: string + format: date-time + last_used: + type: string + format: date-time + revoked: + type: boolean + required: + - id + - name + - created_at + - revoked + ProxyTokenCreated: + type: object + description: Returned on creation — plain_token is shown only once + allOf: + - $ref: '#/components/schemas/ProxyToken' + - type: object + properties: + plain_token: + type: string + 
description: The plain text token (shown only once) + example: "nbx_abc123..." + required: + - plain_token ProxyCluster: type: object description: A proxy cluster represents a group of proxy nodes serving the same address properties: + id: + type: string + description: Unique identifier of a proxy in this cluster + example: "chlfq4q5r8kc73b0qjpg" address: type: string description: Cluster address used for CNAME targets @@ -3367,9 +3421,15 @@ components: type: integer description: Number of proxy nodes connected in this cluster example: 3 + self_hosted: + type: boolean + description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + example: false required: + - id - address - connected_proxies + - self_hosted ReverseProxyDomainType: type: string description: Type of Reverse Proxy Domain @@ -11375,6 +11435,111 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/clusters/{clusterAddress}: + delete: + summary: Delete a self-hosted proxy cluster + description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account. 
+ tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: clusterAddress + required: true + schema: + type: string + description: The address of the proxy cluster + responses: + '200': + description: Proxy cluster deleted successfully + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens: + get: + summary: List Proxy Tokens + description: Returns all proxy access tokens for the account + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy tokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyToken' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Proxy Token + description: Generate an account-scoped proxy access token for self-hosted proxy registration + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenRequest' + responses: + '200': + description: Proxy token created (plain token shown once) + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenCreated' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens/{tokenId}: + delete: + summary: Revoke a Proxy Token + description: 
Revoke an account-scoped proxy access token + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: tokenId + required: true + schema: + type: string + description: The unique identifier of the proxy token + responses: + '200': + description: Token revoked + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 4b94ea01c..b3bb475a9 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -3785,11 +3785,49 @@ type ProxyAccessLogsResponse struct { // ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address type ProxyCluster struct { + // Id Unique identifier of a proxy in this cluster + Id string `json:"id"` + // Address Cluster address used for CNAME targets Address string `json:"address"` // ConnectedProxies Number of proxy nodes connected in this cluster ConnectedProxies int `json:"connected_proxies"` + + // SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + SelfHosted bool `json:"self_hosted"` +} + +// ProxyToken defines model for ProxyToken. +type ProxyToken struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenCreated defines model for ProxyTokenCreated. 
+type ProxyTokenCreated struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + + // PlainToken The plain text token (shown only once) + PlainToken string `json:"plain_token"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenRequest defines model for ProxyTokenRequest. +type ProxyTokenRequest struct { + // ExpiresIn Token expiration in seconds (0 = never expires) + ExpiresIn *int `json:"expires_in,omitempty"` + + // Name Human-readable token name + Name string `json:"name"` } // Resource defines model for Resource. @@ -5160,6 +5198,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate // PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType. type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest +// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType. +type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest + // PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType. 
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest From 946ce4c3da24126ae34cc3b482277923863cbddb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 12 May 2026 00:48:21 +0900 Subject: [PATCH 02/17] [client] Fix --config flag default to point at profile path (#6122) --- client/cmd/root.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 29d4328a1..0a0aa4197 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -143,7 +143,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") - rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location") + rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location") rootCmd.AddCommand(upCmd) rootCmd.AddCommand(downCmd) From 96672dd1f8b116d2c40339bb94c61ed812c508e0 Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Tue, 12 May 2026 13:50:35 +0200 Subject: [PATCH 03/17] [management] chores: update dex version (#6124) * chores: update dex version * chore: update dex fork --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 5704887ce..7c1a95e79 100644 --- a/go.mod +++ b/go.mod @@ -341,8 +341,8 @@ replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 -replace 
github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 +replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 -replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 +replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index 42652169c..53789f49d 100644 --- a/go.sum +++ b/go.sum @@ -485,10 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 h1:AP7OM/JnTogod3rVcLsMuilSG94kWQCr3z6R4rfVXnc= -github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2/go.mod h1:+trSlzHNmdJGvz0oLEyyiuaPstUeD7YO6B3Fx9nyziY= -github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 h1:HEEGJPsVw7/p7SEL3HWP4vaInxHo8OJSEaOkHpUAk+M= -github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M= +github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 h1:4TaYr9O4xX0D2kszeOLclTiCbA3eHq3xWV+9ILJbIYs= +github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1/go.mod h1:IHH+H8vK2GfqtIt5u/5OdPh18yk0oDHuj2vz5+Goetg= +github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 h1:neE7z+FPUkldl3faK/Jt+hJK2L+1XfQ1W33TQhU9m88= +github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M= github.com/netbirdio/easyjson v0.9.0 
h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= From 1224d6e1eeb04d6d43ca05f3bdbdc697b0b7a182 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 12 May 2026 21:52:56 +0900 Subject: [PATCH 04/17] [client] Persist management URL and pre-shared key overrides on login (#6065) --- client/server/login_overrides_test.go | 93 +++++++++++++++++++++++++++ client/server/server.go | 33 +++++++++- 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 client/server/login_overrides_test.go diff --git a/client/server/login_overrides_test.go b/client/server/login_overrides_test.go new file mode 100644 index 000000000..c45557c59 --- /dev/null +++ b/client/server/login_overrides_test.go @@ -0,0 +1,93 @@ +package server + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/profilemanager" +) + +func TestPersistLoginOverrides(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + initialMgmtURL string + initialPSK string + newMgmtURL string + newPSK *string + wantMgmtURL string + wantPSK string + }{ + { + name: "persist new management URL", + initialMgmtURL: "https://old.example.com:33073", + newMgmtURL: "https://new.example.com:33073", + wantMgmtURL: "https://new.example.com:33073", + }, + { + name: "persist new pre-shared key", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "old-key", + newPSK: strPtr("new-key"), + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "new-key", + }, + { + name: "persist both", + initialMgmtURL: "https://old.example.com:33073", + initialPSK: "old-key", + newMgmtURL: "https://new.example.com:33073", + newPSK: strPtr("new-key"), + 
wantMgmtURL: "https://new.example.com:33073", + wantPSK: "new-key", + }, + { + name: "no inputs preserves existing", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "existing-key", + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "existing-key", + }, + { + name: "empty PSK pointer is ignored", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "existing-key", + newPSK: strPtr(""), + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "existing-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origDefault := profilemanager.DefaultConfigPath + t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault }) + + dir := t.TempDir() + profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json") + + seed := profilemanager.ConfigInput{ + ConfigPath: profilemanager.DefaultConfigPath, + ManagementURL: tt.initialMgmtURL, + } + if tt.initialPSK != "" { + seed.PreSharedKey = strPtr(tt.initialPSK) + } + _, err := profilemanager.UpdateOrCreateConfig(seed) + require.NoError(t, err, "seed config") + + activeProf := &profilemanager.ActiveProfileState{Name: "default"} + err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK) + require.NoError(t, err, "persistLoginOverrides") + + cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath) + require.NoError(t, err, "read back config") + + require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL") + require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key") + }) + } +} diff --git a/client/server/server.go b/client/server/server.go index bc8de8f9f..397fb37e4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -490,6 +490,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.mutex.Unlock() + if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil { + log.Errorf("failed to persist 
login overrides: %v", err) + return nil, fmt.Errorf("persist login overrides: %w", err) + } + config, _, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) @@ -964,7 +969,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe return &proto.LogoutResponse{}, nil } -// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist +// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) { cfgPath, err := activeProf.FilePath() if err != nil { @@ -1766,3 +1771,29 @@ func sendTerminalNotification() error { return wallCmd.Wait() } + +// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the +// active profile config so that subsequent reads pick them up. Empty/nil values are ignored. 
+func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error { + if preSharedKey != nil && *preSharedKey == "" { + preSharedKey = nil + } + if managementURL == "" && preSharedKey == nil { + return nil + } + + cfgPath, err := activeProf.FilePath() + if err != nil { + return fmt.Errorf("active profile file path: %w", err) + } + + input := profilemanager.ConfigInput{ + ConfigPath: cfgPath, + ManagementURL: managementURL, + PreSharedKey: preSharedKey, + } + if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil { + return fmt.Errorf("update config: %w", err) + } + return nil +} From 9126a192ca3f6845b798c948e1b1bc05bb7db965 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 12 May 2026 22:05:53 +0900 Subject: [PATCH 05/17] [client] Set 0644 perms on SSH client config after os.CreateTemp (#6126) --- client/ssh/config/manager.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index b58bf2233..20695cb4d 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -252,6 +252,10 @@ func (m *Manager) writeSSHConfig(sshConfig string) error { return fmt.Errorf("write SSH config file %s: %w", tmpPath, err) } + if err := os.Chmod(tmpPath, 0644); err != nil { + return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, sshConfigPath); err != nil { return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err) } From ab2a8794e7a41693fff303725c96399ca190e8ff Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 14 May 2026 12:30:42 +0200 Subject: [PATCH 06/17] [client] Add short flags for status command options (#6137) * [client] Add short flags for status command options * uppercase filters --- client/cmd/status.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/client/cmd/status.go 
b/client/cmd/status.go index dae30e854..103b3044a 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -43,16 +43,16 @@ func init() { ipsFilterMap = make(map[string]struct{}) prefixNamesFilterMap = make(map[string]struct{}) statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format") - statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") - statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") - statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") - statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer") + statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format") + statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format") + statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") + statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6") - statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") - statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") - statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection 
status(idle|connecting|connected), e.g., --filter-by-status connected") - statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") - statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") + statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") + statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") + statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") + statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") + statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") } func statusFunc(cmd *cobra.Command, args []string) error { From 77b479286e399660ef2bdcbe7983363946660574 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 14 May 2026 13:27:50 +0200 Subject: [PATCH 07/17] [management] fix offline statuses for public proxy clusters (#6133) --- .../reverseproxy/domain/manager/manager.go | 23 ++++++- .../domain/manager/manager_test.go | 60 ++++++++++++++++--- .../reverseproxy/service/manager/manager.go | 59 +++++++++--------- management/server/store/sql_store.go | 28 +++++++-- 
management/server/store/sql_store_test.go | 52 ++++++++++++++++ 5 files changed, 177 insertions(+), 45 deletions(-) diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index ab899e0bf..2790b5f20 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s if err != nil { return nil, fmt.Errorf("get BYOP cluster addresses: %w", err) } - if len(byopAddresses) > 0 { - return byopAddresses, nil + publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return nil, fmt.Errorf("get public cluster addresses: %w", err) } - return m.proxyManager.GetActiveClusterAddresses(ctx) + seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses)) + merged := make([]string, 0, len(byopAddresses)+len(publicAddresses)) + for _, addr := range byopAddresses { + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + merged = append(merged, addr) + } + for _, addr := range publicAddresses { + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + merged = append(merged, addr) + } + return merged, nil } func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go index fdeb0765f..5e7bbfc36 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) return nil } -func TestGetClusterAllowList_BYOPProxy(t *testing.T) { 
+func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { assert.Equal(t, "acc-123", accID) return []string{"byop.example.com"}, nil }, getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist") - return nil, nil + return []string{"eu.proxy.netbird.io"}, nil }, } mgr := Manager{proxyManager: pm} result, err := mgr.getClusterAllowList(context.Background(), "acc-123") require.NoError(t, err) - assert.Equal(t, []string{"byop.example.com"}, result) + assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"shared.example.com", "byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result) } func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { @@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { return nil, errors.New("db error") }, - getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") - return nil, nil - }, } mgr := Manager{proxyManager: pm} @@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t 
*testing.T) { assert.Contains(t, err.Error(), "BYOP cluster addresses") } +func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return nil, errors.New("db error") + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "public cluster addresses") +} + func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { @@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) } +func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byop.example.com"}, result) +} + diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index c866d8f75..4a8598afb 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -306,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus 
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { customPorts := m.clusterCustomPorts(ctx, svc) + if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil { + return err + } + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if svc.Domain != "" { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { @@ -321,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc * return err } - if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { - return err - } - if err := transaction.CreateService(ctx, svc); err != nil { return fmt.Errorf("create service: %w", err) } @@ -435,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { customPorts := m.clusterCustomPorts(ctx, svc) + if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil { + return err + } + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { return err @@ -448,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee return err } - if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { - return err - } - if err := transaction.CreateService(ctx, svc); err != nil { return fmt.Errorf("create service: %w", err) } @@ -552,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se svcForCaps.ProxyCluster = effectiveCluster customPorts := m.clusterCustomPorts(ctx, &svcForCaps) + if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil { + return nil, err + } + + // 
Validate subdomain requirement *before* the transaction: the underlying + // capability lookup talks to the main DB pool, and SQLite's single-connection + // pool would self-deadlock if this ran while the tx already held the only + // connection. + if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil { + return nil, err + } + var updateInfo serviceUpdateInfo err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts) + return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster) }) return &updateInfo, err @@ -585,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, return existing.ProxyCluster, nil } -func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error { +func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error { existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) if err != nil { return err @@ -603,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St updateInfo.domainChanged = existingService.Domain != service.Domain if updateInfo.domainChanged { - if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { + if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil { return err } } else { service.ProxyCluster = existingService.ProxyCluster } - if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { - return err - } - m.preserveExistingAuthSecrets(service, 
existingService) if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { return err @@ -628,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St if err := m.checkPortConflict(ctx, transaction, service); err != nil { return err } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } if err := transaction.UpdateService(ctx, service); err != nil { return fmt.Errorf("update service: %w", err) } @@ -638,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St return nil } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { +// handleDomainChange validates the new domain is free inside the transaction +// and applies the pre-resolved cluster (computed outside the tx by +// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks +// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1) +// because the transaction already holds the only connection. 
+func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err } - - if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) - if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) - } else { - svc.ProxyCluster = newCluster - } + if effectiveCluster != "" { + svc.ProxyCluster = effectiveCluster } - return nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4c2f0be52..893ee2168 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "net/url" "os" "path/filepath" "runtime" @@ -2794,12 +2795,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe connStr = filepath.Join(dataDir, filePath) } - // Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows - if hasQuery { - connStr += "?" + query - } else if runtime.GOOS != "windows" { + // Compose query parameters. User-provided ?_busy_timeout (or its mattn alias + // ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at + // most that long on a lock instead of blocking the only Go-side connection. + // mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so + // the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared + // stays the default on non-Windows for the same reason as before. 
+ parsed, _ := url.ParseQuery(query) + var defaults []string + if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" { + defaults = append(defaults, "_busy_timeout=30000") + } + if !hasQuery && runtime.GOOS != "windows" { // To avoid `The process cannot access the file because it is being used by another process` on Windows - connStr += "?cache=shared" + defaults = append(defaults, "cache=shared") + } + parts := defaults + if hasQuery { + parts = append(parts, query) + } + if len(parts) > 0 { + connStr += "?" + strings.Join(parts, "&") } db, err := gorm.Open(sqlite.Open(connStr), getGormConfig()) @@ -3402,7 +3418,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { - timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout) defer cancel() startTime := time.Now() diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2819265c3..7515add62 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -4592,3 +4592,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) { require.NoError(t, err) assert.Equal(t, 0, len(remainingRecords)) } + +// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies +// that the _busy_timeout DSN parameter took effect at the driver level. Without +// this, lock contention on the single SQLite connection waits indefinitely on +// the Go side and can be hidden behind the 5-minute transactionTimeout. 
+func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) { + dir := t.TempDir() + store, err := NewSqliteStore(context.Background(), dir, nil, true) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close(context.Background()) + }) + + sqlDB, err := store.db.DB() + require.NoError(t, err) + row := sqlDB.QueryRow("PRAGMA busy_timeout") + var busyTimeout int + require.NoError(t, row.Scan(&busyTimeout)) + assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling") +} + +// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator +// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE +// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore. +func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) { + cases := []struct { + name string + envFile string + expected int + }{ + {name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000}, + {name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile) + dir := t.TempDir() + store, err := NewSqliteStore(context.Background(), dir, nil, true) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close(context.Background()) + }) + + sqlDB, err := store.db.DB() + require.NoError(t, err) + row := sqlDB.QueryRow("PRAGMA busy_timeout") + var busyTimeout int + require.NoError(t, row.Scan(&busyTimeout)) + assert.Equal(t, tc.expected, busyTimeout) + }) + } +} From ea9fab4396fc5513f7c62e3465dd361dc8bb9e91 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:05:33 +0900 Subject: [PATCH 08/17] [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) --- management/server/account.go | 12 +++++++++ 
management/server/peer.go | 17 +++++++----- management/server/types/account.go | 30 ++++++++++----------- management/server/types/group.go | 5 +++- management/server/types/ipv6_groups_test.go | 30 +++++++++++++++++++++ 5 files changed, 70 insertions(+), 24 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 364c0c37b..77a46a069 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2487,6 +2487,18 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran allowedPeers[peerID] = struct{}{} } } + + // Embedded proxy peers sit outside regular group membership but must + // participate in any v6-enabled overlay to reach v6-only peers. + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + if err != nil { + return nil, fmt.Errorf("get peers: %w", err) + } + for _, p := range peers { + if p.ProxyMeta.Embedded { + allowedPeers[p.ID] = struct{}{} + } + } return allowedPeers, nil } diff --git a/management/server/peer.go b/management/server/peer.go index 8a39fbbb8..c3b130ba2 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -762,16 +762,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe newPeer.IP = freeIP if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil { - var allGroupID string - if !peer.ProxyMeta.Embedded { - allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All") - if err != nil { - log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) - } else { + // Embedded proxy peers are not group members but participate in any + // IPv6-enabled overlay so reverse-proxy traffic reaches v6-only peers. 
+ allocate := peer.ProxyMeta.Embedded + if !allocate { + var allGroupID string + if allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, types.GroupAllName); err == nil { allGroupID = allGroup.ID + } else { + log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) } + allocate = peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) } - if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) { + if allocate { v6Prefix, err := netip.ParsePrefix(network.NetV6.String()) if err != nil { return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err) diff --git a/management/server/types/account.go b/management/server/types/account.go index 49600163a..870333a60 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -598,28 +598,21 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap { return groupList } -// PeerIPv6Allowed reports whether the given peer is in any of the account's IPv6 enabled groups. +// PeerIPv6Allowed reports whether the given peer participates in the IPv6 overlay. // Returns false if IPv6 is disabled or no groups are configured. func (a *Account) PeerIPv6Allowed(peerID string) bool { - if len(a.Settings.IPv6EnabledGroups) == 0 { - return false - } - - for _, groupID := range a.Settings.IPv6EnabledGroups { - group, ok := a.Groups[groupID] - if !ok { - continue - } - if slices.Contains(group.Peers, peerID) { - return true - } - } - return false + _, ok := a.peerIPv6AllowedSet()[peerID] + return ok } -// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group. +// peerIPv6AllowedSet returns the set of peer IDs that participate in the IPv6 overlay: +// members of any IPv6-enabled group, plus every embedded proxy peer (which sit outside +// regular group membership but must reach v6-enabled peers). 
func (a *Account) peerIPv6AllowedSet() map[string]struct{} { result := make(map[string]struct{}) + if len(a.Settings.IPv6EnabledGroups) == 0 { + return result + } for _, groupID := range a.Settings.IPv6EnabledGroups { group, ok := a.Groups[groupID] if !ok { @@ -629,6 +622,11 @@ func (a *Account) peerIPv6AllowedSet() map[string]struct{} { result[peerID] = struct{}{} } } + for id, p := range a.Peers { + if p != nil && p.ProxyMeta.Embedded { + result[id] = struct{}{} + } + } return result } diff --git a/management/server/types/group.go b/management/server/types/group.go index 00fdf7a69..b4f50080a 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -92,9 +92,12 @@ func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } +// GroupAllName is the reserved name of the default group that contains every peer in an account. +const GroupAllName = "All" + // IsGroupAll checks if the group is a default "All" group. func (g *Group) IsGroupAll() bool { - return g.Name == "All" + return g.Name == GroupAllName } // AddPeer adds peerID to Peers if not present, returning true if added. 
diff --git a/management/server/types/ipv6_groups_test.go b/management/server/types/ipv6_groups_test.go index 5151e1b1f..766a9c92c 100644 --- a/management/server/types/ipv6_groups_test.go +++ b/management/server/types/ipv6_groups_test.go @@ -232,3 +232,33 @@ func TestIPv6RecalculationOnGroupChange(t *testing.T) { assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra") }) } + +func TestPeerIPv6AllowedEmbeddedProxy(t *testing.T) { + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peer1": {ID: "peer1"}, + "proxy": {ID: "proxy", ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "netbird.test"}}, + }, + Groups: map[string]*Group{ + "group-devs": {ID: "group-devs", Peers: []string{"peer1"}}, + }, + Settings: &Settings{}, + } + + t.Run("embedded proxy allowed when any v6 group exists, without group membership", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = []string{"group-devs"} + assert.True(t, account.PeerIPv6Allowed("proxy"), "embedded proxy participates in v6 overlay") + assert.True(t, account.PeerIPv6Allowed("peer1"), "regular peer in enabled group still allowed") + }) + + t.Run("embedded proxy denied when no v6 group enabled", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = nil + assert.False(t, account.PeerIPv6Allowed("proxy"), "v6 disabled account-wide denies embedded proxies too") + }) + + t.Run("non-embedded peer outside any enabled group is not pulled in", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = []string{"group-devs"} + account.Peers["lonely"] = &nbpeer.Peer{ID: "lonely"} + assert.False(t, account.PeerIPv6Allowed("lonely"), "embedded-proxy bypass must not leak to regular peers") + }) +} From 3f914090cbb345707a88b5edb20d9c1351873b4c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:22:53 +0900 Subject: [PATCH 09/17] [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) --- client/embed/embed.go | 4 +- 
client/internal/debug/debug.go | 36 +++-- client/internal/debug/debug_linux.go | 195 +++++++++++++++++------- client/internal/debug/debug_nonlinux.go | 5 + 4 files changed, 178 insertions(+), 62 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index 4b9445b97..8b669e547 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { if err != nil { return nil, fmt.Errorf("split host port: %w", err) } - listenAddr := fmt.Sprintf("%s:%s", addr, port) + listenAddr := net.JoinHostPort(addr.String(), port) tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr) if err != nil { @@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) { if err != nil { return nil, fmt.Errorf("split host port: %w", err) } - listenAddr := fmt.Sprintf("%s:%s", addr, port) + listenAddr := net.JoinHostPort(addr.String(), port) udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 9c50f02b3..ebaf71b21 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. -iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. -nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. 
+iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided. +ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided. +ipset.txt: Anonymized ipset list output, if --system-info flag was provided. +nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided. +sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only). resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. @@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird Other non-sensitive configuration options are included without anonymization. 
Firewall Rules (Linux only) -The bundle includes two separate firewall rule files: +The bundle includes the following firewall-related files: iptables.txt: -- Complete iptables ruleset with packet counters using 'iptables -v -n -L' +- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L' - Includes all tables (filter, nat, mangle, raw, security) - Shows packet and byte counters for each rule - All IP addresses are anonymized - Chain names, table names, and other non-sensitive information remain unchanged +ip6tables.txt: +- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L' +- Same table coverage and anonymization as iptables.txt +- Omitted when ip6tables is not installed or no IPv6 rules are present + +ipset.txt: +- Output of 'ipset list' (family-agnostic) +- IP addresses are anonymized; set names and types remain unchanged + nftables.txt: -- Complete nftables ruleset obtained via 'nft -a list ruleset' +- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset' - Includes rule handle numbers and packet counters -- All tables, chains, and rules are included -- Shows packet and byte counters for each rule -- All IP addresses are anonymized -- Chain names, table names, and other non-sensitive information remain unchanged +- All IP addresses are anonymized; chain/table names remain unchanged + +sysctls.txt: +- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify +- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf +- Interface names anonymized when --anonymize is set IP Rules (Linux only) The ip_rules.txt file contains detailed IP routing rule information: @@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + if err := 
g.addSysctls(); err != nil { + log.Errorf("failed to add sysctls to debug bundle: %v", err) + } + if err := g.addDNSInfo(); err != nil { log.Errorf("failed to add DNS info to debug bundle: %v", err) } diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index aedf88b79..40d864eda 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) { // addFirewallRules collects and adds firewall rules to the archive func (g *BundleGenerator) addFirewallRules() error { log.Info("Collecting firewall rules") - iptablesRules, err := collectIPTablesRules() + g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt") + g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt") + + ipsetOutput, err := collectIPSets() if err != nil { - log.Warnf("Failed to collect iptables rules: %v", err) + log.Warnf("Failed to collect ipset information: %v", err) } else { if g.anonymize { - iptablesRules = g.anonymizer.AnonymizeString(iptablesRules) + ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput) } - if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil { - log.Warnf("Failed to add iptables rules to bundle: %v", err) + if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil { + log.Warnf("Failed to add ipset output to bundle: %v", err) } } @@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error { return nil } -// collectIPTablesRules collects rules using both iptables-save and verbose listing -func collectIPTablesRules() (string, error) { - var builder strings.Builder - - saveOutput, err := collectIPTablesSave() +// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle. 
+func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) { + rules, err := collectIPTablesRules(saveBin, listBin) if err != nil { - log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) - } else { - builder.WriteString("=== iptables-save output ===\n") + log.Warnf("Failed to collect %s rules: %v", listBin, err) + return + } + if g.anonymize { + rules = g.anonymizer.AnonymizeString(rules) + } + if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil { + log.Warnf("Failed to add %s rules to bundle: %v", listBin, err) + } +} + +// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>. +// Returns an error when neither command produced any output (e.g. the binary is missing), +// so the caller can skip writing an empty file. +func collectIPTablesRules(saveBin, listBin string) (string, error) { + var builder strings.Builder + var collected bool + var firstErr error + + saveOutput, err := runCommand(saveBin) + switch { + case err != nil: + firstErr = err + log.Warnf("Failed to collect %s output: %v", saveBin, err) + case strings.TrimSpace(saveOutput) == "": + log.Debugf("%s produced no output, skipping", saveBin) + default: + builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin)) builder.WriteString(saveOutput) builder.WriteString("\n") + collected = true } - ipsetOutput, err := collectIPSets() - if err != nil { - log.Warnf("Failed to collect ipset information: %v", err) - } else { - builder.WriteString("=== ipset list output ===\n") - builder.WriteString(ipsetOutput) - builder.WriteString("\n") - } - - builder.WriteString("=== iptables -v -n -L output ===\n") + listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin) + builder.WriteString(listHeader) tables := []string{"filter", "nat", "mangle", "raw", "security"} - for _, table := range tables { - builder.WriteString(fmt.Sprintf("*%s\n", table)) - - stats, err := getTableStatistics(table) + stats, err := 
runCommand(listBin, "-v", "-n", "-L", "-t", table) if err != nil { - log.Warnf("Failed to get statistics for table %s: %v", table, err) + if firstErr == nil { + firstErr = err + } + log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err) continue } + builder.WriteString(fmt.Sprintf("*%s\n", table)) builder.WriteString(stats) builder.WriteString("\n") + collected = true } + if !collected { + return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr) + } return builder.String(), nil } @@ -214,34 +238,15 @@ func collectIPSets() (string, error) { return ipsets, nil } -// collectIPTablesSave uses iptables-save to get rule definitions -func collectIPTablesSave() (string, error) { - cmd := exec.Command("iptables-save") +// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure. +func runCommand(name string, args ...string) (string, error) { + cmd := exec.Command(name, args...) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { - return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) - } - - rules := stdout.String() - if strings.TrimSpace(rules) == "" { - return "", fmt.Errorf("no iptables rules found") - } - - return rules, nil -} - -// getTableStatistics gets verbose statistics for an entire table using iptables command -func getTableStatistics(table string) (string, error) { - cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) + return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String()) } return stdout.String(), nil @@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string { return fmt.Sprintf("type-%v", keyType) } } + +// addSysctls collects forwarding 
and netbird-managed sysctl values and writes them to the bundle. +func (g *BundleGenerator) addSysctls() error { + log.Info("Collecting sysctls") + content := collectSysctls() + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil { + return fmt.Errorf("add sysctls to bundle: %w", err) + } + return nil +} + +// collectSysctls reads every sysctl that the netbird client may modify, plus +// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic. +// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf. +func collectSysctls() string { + var builder strings.Builder + + writeSysctlGroup(&builder, "forwarding", []string{ + "net.ipv4.ip_forward", + "net.ipv6.conf.all.forwarding", + "net.ipv6.conf.default.forwarding", + }) + writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding")) + writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding")) + writeSysctlGroup(&builder, "rp_filter", append( + []string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"}, + listInterfaceSysctls("ipv4", "rp_filter")..., + )) + writeSysctlGroup(&builder, "src_valid_mark", append( + []string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"}, + listInterfaceSysctls("ipv4", "src_valid_mark")..., + )) + writeSysctlGroup(&builder, "conntrack", []string{ + "net.netfilter.nf_conntrack_acct", + "net.netfilter.nf_conntrack_tcp_loose", + }) + writeSysctlGroup(&builder, "tcp", []string{ + "net.ipv4.tcp_tw_reuse", + }) + + return builder.String() +} + +func writeSysctlGroup(builder *strings.Builder, title string, keys []string) { + builder.WriteString(fmt.Sprintf("=== %s ===\n", title)) + for _, key := range keys { + value, err := readSysctl(key) + if err != nil { + builder.WriteString(fmt.Sprintf("%s = \n", key, err)) + continue + } + 
+ builder.WriteString(fmt.Sprintf("%s = %s\n", key, value)) + } + builder.WriteString("\n") +} + +// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every +// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default" +// (callers add those explicitly so they appear first). +func listInterfaceSysctls(family, leaf string) []string { + dir := fmt.Sprintf("/proc/sys/net/%s/conf", family) + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + var keys []string + for _, e := range entries { + name := e.Name() + if name == "all" || name == "default" { + continue + } + keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf)) + } + sort.Strings(keys) + return keys +} + +func readSysctl(key string) (string, error) { + path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/")) + value, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(value)), nil +} diff --git a/client/internal/debug/debug_nonlinux.go b/client/internal/debug/debug_nonlinux.go index ace53bd94..878fee40f 100644 --- a/client/internal/debug/debug_nonlinux.go +++ b/client/internal/debug/debug_nonlinux.go @@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error { // IP rules are only supported on Linux return nil } + +func (g *BundleGenerator) addSysctls() error { + // Sysctl collection is only supported on Linux + return nil +} From 07e5450117dd0451aaeefc18729a822115587e69 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:42:40 +0900 Subject: [PATCH 10/17] [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) --- .../modules/reverseproxy/service/service.go | 18 ++++- .../reverseproxy/service/service_test.go | 77 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/service.go 
b/management/internals/modules/reverseproxy/service/service.go index 769e037bc..166a66a5f 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping { } // HTTP/HTTPS: build full URL + hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]") targetURL := url.URL{ Scheme: target.Protocol, - Host: target.Host, + Host: bracketIPv6Host(hostNoBrackets), Path: "/", } if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10)) } path := "/" @@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping { return pathMappings } +// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as +// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6 +// addresses are bracketed too since their textual form contains colons. 
+func bracketIPv6Host(host string) string { + if strings.HasPrefix(host, "[") { + return host + } + if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() { + return "[" + host + "]" + } + return host +} + func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { switch op { case Create: diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index ff54cb79f..f1349ff65 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) { port: 80, wantTarget: "https://10.0.0.1:80/", }, + { + name: "domain host without port is unchanged", + protocol: "http", + host: "example.com", + port: 0, + wantTarget: "http://example.com/", + }, + { + name: "domain host with non-default port is unchanged", + protocol: "http", + host: "example.com", + port: 8080, + wantTarget: "http://example.com:8080/", + }, + { + name: "ipv6 host without port is bracketed", + protocol: "http", + host: "fb00:cafe:1::3", + port: 0, + wantTarget: "http://[fb00:cafe:1::3]/", + }, + { + name: "ipv6 host with default port omits port and brackets host", + protocol: "http", + host: "fb00:cafe:1::3", + port: 80, + wantTarget: "http://[fb00:cafe:1::3]/", + }, + { + name: "ipv6 host with non-default port is bracketed", + protocol: "http", + host: "fb00:cafe:1::3", + port: 8080, + wantTarget: "http://[fb00:cafe:1::3]:8080/", + }, + { + name: "ipv6 loopback without port is bracketed", + protocol: "http", + host: "::1", + port: 0, + wantTarget: "http://[::1]/", + }, + { + name: "ipv6 host with 5-digit port is bracketed", + protocol: "http", + host: "fb00:cafe::1", + port: 18080, + wantTarget: "http://[fb00:cafe::1]:18080/", + }, + { + name: "pre-bracketed ipv6 without port stays single-bracketed", + protocol: "http", + host: "[fb00:cafe::1]", 
+ port: 0, + wantTarget: "http://[fb00:cafe::1]/", + }, + { + name: "pre-bracketed ipv6 with port is not double-bracketed", + protocol: "http", + host: "[fb00:cafe::1]", + port: 8080, + wantTarget: "http://[fb00:cafe::1]:8080/", + }, + { + name: "v4-mapped ipv6 host without port is bracketed", + protocol: "http", + host: "::ffff:10.0.0.1", + port: 0, + wantTarget: "http://[::ffff:10.0.0.1]/", + }, + { + name: "full-form 8-group ipv6 without port is bracketed", + protocol: "http", + host: "fb00:cafe:1:0:0:0:0:3", + port: 0, + wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 2ccae7ec479c106efb6d7a7edff4bb55affb2aa4 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 May 2026 23:58:47 +0900 Subject: [PATCH 11/17] [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) --- client/internal/routemanager/manager.go | 5 +- .../internal/routeselector/routeselector.go | 84 ++++++----- .../routeselector/routeselector_test.go | 131 ++++++++++++++++++ client/ui/network.go | 10 +- 4 files changed, 197 insertions(+), 33 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e5d9363ca..907f1f592 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -704,7 +704,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI } func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool { - return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR + if len(routes) == 0 { + return false + } + return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network) } func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) { diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go 
index 30afc013b..2ddc24bf2 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "slices" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,10 +13,6 @@ import ( "github.com/netbirdio/netbird/route" ) -const ( - exitNodeCIDR = "0.0.0.0/0" -) - type RouteSelector struct { mu sync.RWMutex deselectedRoutes map[route.NetID]struct{} @@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.deselectAll { - return false - } - - _, deselected := rs.deselectedRoutes[routeID] - isSelected := !deselected - return isSelected + return rs.isSelectedLocked(routeID) } // FilterSelected removes unselected routes from the provided map. @@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { filtered := route.HAMap{} for id, rt := range routes { - netID := id.NetID() - _, deselected := rs.deselectedRoutes[netID] - if !deselected { + if !rs.isDeselectedLocked(id.NetID()) { filtered[id] = rt } } return filtered } -// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route +// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route. +// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of +// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing +// transparently apply to the synthesized v6 entry. 
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - _, selected := rs.selectedRoutes[routeID] - _, deselected := rs.deselectedRoutes[routeID] - return selected || deselected + return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID)) } func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap { @@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap filtered := make(route.HAMap, len(routes)) for id, rt := range routes { netID := id.NetID() - if rs.isDeselected(netID) { + if rs.isDeselectedLocked(netID) { continue } @@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap return filtered } -func (rs *RouteSelector) isDeselected(netID route.NetID) bool { +// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit +// state of its own, so selections made on the v4 entry govern the v6 entry automatically. +// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route +// would make it inherit unrelated v4 state. Must be called with rs.mu held. 
+func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID { + name := string(id) + if !strings.HasSuffix(name, route.V6ExitSuffix) { + return id + } + if _, ok := rs.selectedRoutes[id]; ok { + return id + } + if _, ok := rs.deselectedRoutes[id]; ok { + return id + } + return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix)) +} + +func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool { + if rs.deselectAll { + return false + } + _, deselected := rs.deselectedRoutes[routeID] + return !deselected +} + +func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool { + if rs.deselectAll { + return true + } _, deselected := rs.deselectedRoutes[netID] - return deselected || rs.deselectAll + return deselected +} + +func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool { + _, selected := rs.selectedRoutes[routeID] + _, deselected := rs.deselectedRoutes[routeID] + return selected || deselected } func isExitNode(rt []*route.Route) bool { - return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR + return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network)) } func (rs *RouteSelector) applyExitNodeFilter( @@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter( rt []*route.Route, out route.HAMap, ) { - - if rs.hasUserSelections() { - // user made explicit selects/deselects - if rs.IsSelected(netID) { + // Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also + // drops the synthesized v6 entry that lacks its own explicit state. 
+ effective := rs.effectiveNetID(netID) + if rs.hasUserSelectionForRouteLocked(effective) { + if rs.isSelectedLocked(effective) { out[id] = rt } return } - // no explicit selections: only include routes marked !SkipAutoApply (=AutoApply) + // no explicit selection for this route: defer to management's SkipAutoApply flag sel := collectSelected(rt) if len(sel) > 0 { out[id] = sel } } -func (rs *RouteSelector) hasUserSelections() bool { - return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0 -} - func collectSelected(rt []*route.Route) []*route.Route { var sel []*route.Route for _, r := range rt { diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 5faea2456..3f0d9f120 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) { assert.Len(t, filtered, 0) // No routes should be selected } +// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection +// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute +// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its +// v4 base, so legacy persisted selections that predate v6 pairing transparently +// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected) +// stay literal so unrelated routes named "*-v6" don't inherit unrelated state. 
+func TestRouteSelector_V6ExitPairInherits(t *testing.T) { + all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"} + + t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection") + + // unrelated v6 with no v4 base touched is unaffected + assert.False(t, rs.HasUserSelectionForRoute("exit2-v6")) + }) + + t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all)) + + // A non-exit route literally named "corp-v6" must not inherit "corp"'s state + // via the mirror; the mirror only applies in exit-node code paths. + assert.False(t, rs.IsSelected("corp")) + assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state") + }) + + t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all)) + + v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")} + v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")} + routes := route.HAMap{ + "exit1|0.0.0.0/0": {v4Route}, + "exit1-v6|::/0": {v6Route}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0")) + assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base") + }) + + t.Run("non-v6-suffix routes unaffected", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + // 
A route literally named "exit1-something" must not pair-resolve. + assert.False(t, rs.HasUserSelectionForRoute("exit1-something")) + }) + + t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")} + v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")} + routes := route.HAMap{ + "exit1|0.0.0.0/0": {v4Route}, + "exit1-v6|::/0": {v6Route}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair") + }) + + t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all)) + + // A non-default-route entry named "corp-v6" is not an exit node and + // must not be skipped because its v4 base "corp" is deselected. + corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")} + routes := route.HAMap{ + "corp-v6|10.0.0.0/8": {corpV6}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"), + "non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes") + }) +} + +// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's +// SkipAutoApply flag governs each untouched route independently, even when +// the user has explicit selections on other routes. 
+func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) { + autoApplied := &route.Route{ + NetID: "Auto", + Network: netip.MustParsePrefix("0.0.0.0/0"), + SkipAutoApply: false, + } + skipApply := &route.Route{ + NetID: "Skip", + Network: netip.MustParsePrefix("0.0.0.0/0"), + SkipAutoApply: true, + } + routes := route.HAMap{ + "Auto|0.0.0.0/0": {autoApplied}, + "Skip|0.0.0.0/0": {skipApply}, + } + + rs := routeselector.NewRouteSelector() + // User makes an unrelated explicit selection elsewhere. + require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"})) + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included") + assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection") +} + +// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized +// as exit nodes by the selector's filter path. 
+func TestRouteSelector_V6ExitIsExitNode(t *testing.T) { + v6Exit := &route.Route{ + NetID: "V6Only", + Network: netip.MustParsePrefix("::/0"), + SkipAutoApply: true, + } + routes := route.HAMap{ + "V6Only|::/0": {v6Exit}, + } + + rs := routeselector.NewRouteSelector() + filtered := rs.FilterSelectedExitNodes(routes) + assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply") +} + func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialRoutes := []route.NetID{"route1", "route2", "route3"} newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} diff --git a/client/ui/network.go b/client/ui/network.go index 1619f78a2..cd5d23558 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { } func isDefaultRoute(routeRange string) bool { - return routeRange == "0.0.0.0/0" || routeRange == "::/0" + // routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0", + // "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry. 
+ for _, part := range strings.Split(routeRange, ",") { + switch strings.TrimSpace(part) { + case "0.0.0.0/0", "::/0": + return true + } + } + return false } func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { From 9ed2e2a5b463077f8abe3e3926695f5dc9411e29 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 00:07:38 +0900 Subject: [PATCH 12/17] [client] Drop DNS probes for passive health projection (#5971) --- client/internal/connect.go | 2 - client/internal/dns/host.go | 12 + client/internal/dns/host_android.go | 19 +- client/internal/dns/host_ios.go | 9 + client/internal/dns/host_windows.go | 121 ++- client/internal/dns/hosts_dns_holder.go | 1 + client/internal/dns/local/local.go | 2 - client/internal/dns/mock_server.go | 9 +- client/internal/dns/network_manager_unix.go | 211 ++++- client/internal/dns/server.go | 928 ++++++++++++-------- client/internal/dns/server_android.go | 2 +- client/internal/dns/server_test.go | 698 +++++++++++++-- client/internal/dns/systemd_linux.go | 151 +++- client/internal/dns/upstream.go | 683 +++++++------- client/internal/dns/upstream_android.go | 5 +- client/internal/dns/upstream_general.go | 5 +- client/internal/dns/upstream_ios.go | 17 +- client/internal/dns/upstream_test.go | 227 +++-- client/internal/engine.go | 16 +- client/internal/routemanager/manager.go | 34 + client/internal/routemanager/mock.go | 9 + client/ios/NetBirdSDK/client.go | 6 +- 22 files changed, 2294 insertions(+), 873 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 8c0e9b1ba..ea884818f 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, - dnsAddresses []netip.AddrPort, stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of 
memory for the extension. @@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, - HostDNSAddresses: dnsAddresses, StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, "") diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index f7dc46a6b..48eacef29 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -16,6 +16,10 @@ type hostManager interface { restoreHostDNS() error supportCustomPort() bool string() string + // getOriginalNameservers returns the OS-side resolvers used as PriorityFallback + // upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android, + // hardcoded Quad9 on iOS, nil for noop / mock. + getOriginalNameservers() []netip.Addr } type SystemDNSSettings struct { @@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool { func (n noopHostConfigurator) string() string { return "noop" } + +func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr { + return nil +} + +func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr { + return nil +} diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index dfa3e5712..48b3e0301 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,14 +1,20 @@ package dns import ( + "net/netip" + "github.com/netbirdio/netbird/client/internal/statemanager" ) +// androidHostManager is a noop on the OS side (Android's VPN service handles +// DNS for us) but tracks the OS-reported resolver list pushed via +// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source. 
type androidHostManager struct { + holder *hostsDNSHolder } -func newHostManager() (*androidHostManager, error) { - return &androidHostManager{}, nil +func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) { + return &androidHostManager{holder: holder}, nil } func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { @@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) string() string { return "none" } + +func (a androidHostManager) getOriginalNameservers() []netip.Addr { + hosts := a.holder.get() + out := make([]netip.Addr, 0, len(hosts)) + for ap := range hosts { + out = append(out, ap.Addr()) + } + return out +} diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index 1c0ac63e9..860bb8b50 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,6 +3,7 @@ package dns import ( "encoding/json" "fmt" + "net/netip" log "github.com/sirupsen/logrus" @@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { }, nil } +func (a iosHostManager) getOriginalNameservers() []netip.Addr { + // Quad9 v4+v6: 9.9.9.9, 2620:fe::fe. 
+ return []netip.Addr{ + netip.AddrFrom4([4]byte{9, 9, 9, 9}), + netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}), + } +} + func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 4a8cf8cec..4f6ece532 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -7,6 +7,7 @@ import ( "io" "net/netip" "os/exec" + "slices" "strings" "syscall" "time" @@ -44,9 +45,11 @@ const ( nrptMaxDomainsPerRule = 50 - interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` - interfaceConfigNameServerKey = "NameServer" - interfaceConfigSearchListKey = "SearchList" + interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` + interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces` + interfaceConfigNameServerKey = "NameServer" + interfaceConfigDhcpNameSrvKey = "DhcpNameServer" + interfaceConfigSearchListKey = "SearchList" // Network interface DNS registration settings disableDynamicUpdateKey = "DisableDynamicUpdate" @@ -67,10 +70,11 @@ const ( ) type registryConfigurator struct { - guid string - routingAll bool - gpo bool - nrptEntryCount int + guid string + routingAll bool + gpo bool + nrptEntryCount int + origNameservers []netip.Addr } func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { @@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { gpo: useGPO, } + origNameservers, err := configurator.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from non-WG adapters: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty") + default: + 
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers) + } + configurator.origNameservers = origNameservers + if err := configurator.configureInterface(); err != nil { log.Errorf("failed to configure interface settings: %v", err) } @@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { return configurator, nil } +// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface +// registry key except the WG adapter. v4 and v6 servers live in separate +// hives (Tcpip vs Tcpip6) keyed by the same interface GUID. +func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + var merr *multierror.Error + for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} { + addrs, err := r.captureFromTcpipRoot(root) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err)) + continue + } + for _, addr := range addrs { + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nberrors.FormatErrorOrNil(merr) +} + +func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) { + root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ) + if err != nil { + return nil, fmt.Errorf("open key: %w", err) + } + defer closer(root) + + guids, err := root.ReadSubKeyNames(-1) + if err != nil { + return nil, fmt.Errorf("read subkeys: %w", err) + } + + var out []netip.Addr + for _, guid := range guids { + if strings.EqualFold(guid, r.guid) { + continue + } + out = append(out, readInterfaceNameservers(rootPath, guid)...) 
+ } + return out, nil +} + +func readInterfaceNameservers(rootPath, guid string) []netip.Addr { + keyPath := rootPath + "\\" + guid + k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE) + if err != nil { + return nil + } + defer closer(k) + + // Static NameServer wins over DhcpNameServer for actual resolution. + for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} { + raw, _, err := k.GetStringValue(name) + if err != nil || raw == "" { + continue + } + if out := parseRegistryNameservers(raw); len(out) > 0 { + return out + } + } + return nil +} + +func parseRegistryNameservers(raw string) []netip.Addr { + var out []netip.Addr + for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) { + addr, err := netip.ParseAddr(strings.TrimSpace(field)) + if err != nil { + continue + } + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + continue + } + // Drop unzoned link-local: not routable without a scope id. If + // the user wrote "fe80::1%eth0" ParseAddr preserves the zone. 
+ if addr.IsLinkLocalUnicast() && addr.Zone() == "" { + continue + } + out = append(out, addr) + } + return out +} + +func (r *registryConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(r.origNameservers) +} + func (r *registryConfigurator) supportCustomPort() bool { return false } diff --git a/client/internal/dns/hosts_dns_holder.go b/client/internal/dns/hosts_dns_holder.go index 980d917a7..9ecc397be 100644 --- a/client/internal/dns/hosts_dns_holder.go +++ b/client/internal/dns/hosts_dns_holder.go @@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) { h.mutex.Unlock() } +//nolint:unused func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} { h.mutex.RLock() l := h.unprotectedDNSList diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index e9d310f00..4a75a76b6 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability(context.Context) {} - // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { logger := log.WithFields(log.Fields{ diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 548b1f54f..31fedd9e5 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -9,6 +9,7 @@ import ( dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string { return make([]string, 0) } -// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface -func (m *MockServer) ProbeAvailability() { -} - func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { if m.UpdateServerConfigFunc != 
nil { return m.UpdateServerConfigFunc(domains) @@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } -// SetRouteChecker mock implementation of SetRouteChecker from Server interface -func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { +// SetRouteSources mock implementation of SetRouteSources from Server interface +func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) { // Mock implementation - no-op } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 66d82dcd7..3932e78b7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net/netip" + "slices" "strings" "time" @@ -32,6 +33,15 @@ const ( networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection" networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply" networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete" + networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config" + networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config" + networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface" + networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices" + networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config" + networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config" + networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData" + networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers" + networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers" networkManagerDbusDefaultBehaviorFlag 
networkManagerConfigBehavior = 0 networkManagerDbusIPv4Key = "ipv4" networkManagerDbusIPv6Key = "ipv6" @@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{ } type networkManagerDbusConfigurator struct { - dbusLinkObject dbus.ObjectPath - routingAll bool - ifaceName string + dbusLinkObject dbus.ObjectPath + routingAll bool + ifaceName string + origNameservers []netip.Addr } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface) - return &networkManagerDbusConfigurator{ + c := &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), ifaceName: wgInterface, - }, nil + } + + origNameservers, err := c.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from NetworkManager: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers) + } + c.origNameservers = origNameservers + return c, nil +} + +// captureOriginalNameservers reads DNS servers from every NM device's +// IP4Config / IP6Config except our WG device. 
+func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + devices, err := networkManagerListDevices() + if err != nil { + return nil, fmt.Errorf("list devices: %w", err) + } + + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + for _, dev := range devices { + if dev == n.dbusLinkObject { + continue + } + ifaceName := readNetworkManagerDeviceInterface(dev) + for _, addr := range readNetworkManagerDeviceDNS(dev) { + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + continue + } + // IP6Config.Nameservers is a byte slice without zone info; + // reattach the device's interface name so a captured fe80::… + // stays routable. + if addr.IsLinkLocalUnicast() && ifaceName != "" { + addr = addr.WithZone(ifaceName) + } + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nil +} + +func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string { + obj, closeConn, err := getDbusObject(networkManagerDest, devicePath) + if err != nil { + return "" + } + defer closeConn() + v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty) + if err != nil { + return "" + } + s, _ := v.Value().(string) + return s +} + +func networkManagerListDevices() ([]dbus.ObjectPath, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + return nil, fmt.Errorf("dbus NetworkManager: %w", err) + } + defer closeConn() + var devs []dbus.ObjectPath + if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil { + return nil, err + } + return devs, nil +} + +func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, devicePath) + if err != nil { + return nil + } + defer closeConn() + + var out []netip.Addr + if path := readNetworkManagerConfigPath(obj, 
networkManagerDbusDeviceIp4ConfigProperty); path != "" { + out = append(out, readIPv4ConfigDNS(path)...) + } + if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" { + out = append(out, readIPv6ConfigDNS(path)...) + } + return out +} + +func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath { + v, err := obj.GetProperty(property) + if err != nil { + return "" + } + path, ok := v.Value().(dbus.ObjectPath) + if !ok || path == "/" { + return "" + } + return path +} + +func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, path) + if err != nil { + return nil + } + defer closeConn() + + // NameserverData (NM 1.13+) carries strings; older NMs only expose the + // legacy uint32 Nameservers property. + if out := readIPv4NameserverData(obj); len(out) > 0 { + return out + } + return readIPv4LegacyNameservers(obj) +} + +func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr { + v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty) + if err != nil { + return nil + } + entries, ok := v.Value().([]map[string]dbus.Variant) + if !ok { + return nil + } + var out []netip.Addr + for _, entry := range entries { + addrVar, ok := entry["address"] + if !ok { + continue + } + s, ok := addrVar.Value().(string) + if !ok { + continue + } + if a, err := netip.ParseAddr(s); err == nil { + out = append(out, a) + } + } + return out +} + +func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr { + v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty) + if err != nil { + return nil + } + raw, ok := v.Value().([]uint32) + if !ok { + return nil + } + out := make([]netip.Addr, 0, len(raw)) + for _, n := range raw { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], n) + out = append(out, netip.AddrFrom4(b)) + } + return out +} + +func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr { + obj, 
closeConn, err := getDbusObject(networkManagerDest, path) + if err != nil { + return nil + } + defer closeConn() + v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty) + if err != nil { + return nil + } + raw, ok := v.Value().([][]byte) + if !ok { + return nil + } + out := make([]netip.Addr, 0, len(raw)) + for _, b := range raw { + if a, ok := netip.AddrFromSlice(b); ok { + out = append(out, a) + } + } + return out +} + +func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(n.origNameservers) } func (n *networkManagerDbusConfigurator) supportCustomPort() bool { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6fe2e21b6..e689f3586 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -6,11 +6,10 @@ import ( "fmt" "net/netip" "net/url" - "os" - "runtime" - "strconv" + "slices" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -25,11 +24,31 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" +const ( + // healthLookback must exceed the upstream query timeout so one + // query per refresh cycle is enough to keep a group marked healthy. + healthLookback = 60 * time.Second + nsGroupHealthRefreshInterval = 10 * time.Second + // defaultWarningDelayBase is the starting grace window before a + // "Nameserver group unreachable" event fires for a group that's + // never been healthy and only has overlay upstreams with no + // Connected peer. Per-server and overridable; see warningDelayFor. 
+ defaultWarningDelayBase = 30 * time.Second + // warningDelayBonusCap caps the route-count bonus added to the + // base grace window. See warningDelayFor. + warningDelayBonusCap = 30 * time.Second +) + +// errNoUsableNameservers signals that a merged-domain group has no usable +// upstream servers. Callers should skip the group without treating it as a +// build failure. +var errNoUsableNameservers = errors.New("no usable nameservers") // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { @@ -54,10 +73,9 @@ type Server interface { UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string - ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error - SetRouteChecker(func(netip.Addr) bool) + SetRouteSources(selected, active func() route.HAMap) SetFirewall(Firewall) } @@ -66,12 +84,47 @@ type nsGroupsByDomain struct { groups []*nbdns.NameServerGroup } -// hostManagerWithOriginalNS extends the basic hostManager interface -type hostManagerWithOriginalNS interface { - hostManager - getOriginalNameservers() []netip.Addr +// nsGroupID identifies a nameserver group by the tuple (server list, domain +// list) so config updates produce stable IDs across recomputations. +type nsGroupID string + +// nsHealthSnapshot is the input to projectNSGroupHealth, captured under +// s.mux so projection runs lock-free. +type nsHealthSnapshot struct { + groups []*nbdns.NameServerGroup + merged map[netip.AddrPort]UpstreamHealth + selected route.HAMap + active route.HAMap } +// nsGroupProj holds per-group state for the emission rules. +type nsGroupProj struct { + // unhealthySince is the start of the current Unhealthy streak, + // zero when the group is not currently Unhealthy. 
+ unhealthySince time.Time + // everHealthy is sticky: once the group has been Healthy at least + // once this session, subsequent failures skip warningDelay. + everHealthy bool + // warningActive tracks whether we've already published a warning + // for the current streak, so recovery emits iff a warning did. + warningActive bool +} + +// nsGroupVerdict is the outcome of evaluateNSGroupHealth. +type nsGroupVerdict int + +const ( + // nsVerdictUndecided means no upstream has a fresh observation + // (startup before first query, or records aged past healthLookback). + nsVerdictUndecided nsGroupVerdict = iota + // nsVerdictHealthy means at least one upstream's most-recent + // in-lookback observation is a success. + nsVerdictHealthy + // nsVerdictUnhealthy means at least one upstream has a recent + // failure and none has a fresher success. + nsVerdictUnhealthy +) + // DefaultServer dns server object type DefaultServer struct { ctx context.Context @@ -100,26 +153,46 @@ type DefaultServer struct { permanent bool hostsDNSHolder *hostsDNSHolder + // fallbackHandler is the upstream resolver currently registered at + // PriorityFallback. Tracked so registerFallback can Stop() the previous + // instance instead of leaking its context. + fallbackHandler handlerWithStop + // make sense on mobile only searchDomainNotifier *notifier iosDnsManager IosDnsManager statusRecorder *peer.Status stateManager *statemanager.Manager - routeMatch func(netip.Addr) bool + // selectedRoutes returns admin-enabled client routes. + selectedRoutes func() route.HAMap + // activeRoutes returns the subset whose peer is in StatusConnected. + activeRoutes func() route.HAMap - probeMu sync.Mutex - probeCancel context.CancelFunc - probeWg sync.WaitGroup + nsGroups []*nbdns.NameServerGroup + healthProjectMu sync.Mutex + // nsGroupProj is the per-group state used by the emission rules. + // Accessed only under healthProjectMu. 
+ nsGroupProj map[nsGroupID]*nsGroupProj + // warningDelayBase is the base grace window for health projection. + // Set at construction, mutated only by tests. Read by the + // refresher goroutine so never change it while one is running. + warningDelayBase time.Duration + // healthRefresh is buffered=1; writers coalesce, senders never block. + // See refreshHealth for the lock-order rationale. + healthRefresh chan struct{} } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability(context.Context) ID() types.HandlerID } +type upstreamHealthReporter interface { + UpstreamHealth() map[netip.AddrPort]UpstreamHealth +} + type handlerWrapper struct { domain string handler handlerWithStop @@ -174,7 +247,6 @@ func NewDefaultServerPermanentUpstream( ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.addHostRootZone() ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) ds.searchDomainNotifier.setListener(listener) @@ -182,21 +254,17 @@ func NewDefaultServerPermanentUpstream( return ds } -// NewDefaultServerIos returns a new dns server. It optimized for ios +// NewDefaultServerIos returns a new dns server. It optimized for ios. 
func NewDefaultServerIos( ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager, - hostsDnsList []netip.AddrPort, statusRecorder *peer.Status, disableSys bool, ) *DefaultServer { - log.Debugf("iOS host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager - ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.addHostRootZone() return ds } @@ -230,6 +298,8 @@ func newDefaultServer( hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied + warningDelayBase: defaultWarningDelayBase, + healthRefresh: make(chan struct{}, 1), } // register with root zone, handler chain takes care of the routing @@ -238,12 +308,26 @@ func newDefaultServer( return defaultServer } -// SetRouteChecker sets the function used by upstream resolvers to determine -// whether an IP is routed through the tunnel. -func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) { +// SetRouteSources wires the route-manager accessors used by health +// projection to classify each upstream for emission timing. +func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) { s.mux.Lock() defer s.mux.Unlock() - s.routeMatch = f + s.selectedRoutes = selected + s.activeRoutes = active + + // Permanent / iOS constructors build the root handler before the + // engine wires route sources, so its selectedRoutes callback would + // otherwise remain nil and overlay upstreams would be classified + // as public. Propagate the new accessors to existing handlers. 
+ type routeSettable interface { + setSelectedRoutes(func() route.HAMap) + } + for _, entry := range s.dnsMuxMap { + if h, ok := entry.handler.(routeSettable); ok { + h.setSelectedRoutes(selected) + } + } } // RegisterHandler registers a handler for the given domains with the given priority. @@ -256,7 +340,6 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain for _, domain := range domains { - // convert to zone with simple ref counter s.extraDomains[toZone(domain)]++ } if !s.batchMode { @@ -357,6 +440,8 @@ func (s *DefaultServer) Initialize() (err error) { s.stateManager.RegisterState(&ShutdownState{}) + s.startHealthRefresher() + // Keep using noop host manager if dns off requested or running in netstack mode. // Netstack mode currently doesn't have a way to receive DNS requests. // TODO: Use listener on localhost in netstack mode when running as root. @@ -370,6 +455,13 @@ func (s *DefaultServer) Initialize() (err error) { return fmt.Errorf("initialize: %w", err) } s.hostManager = hostManager + // On mobile-permanent setups the seeded host DNS list is the only + // source until the first network-map arrives; register it now so DNS + // works in that window. Desktop host managers register fallback when + // applyConfiguration runs. 
+ if s.permanent { + s.registerFallback() + } return nil } @@ -394,13 +486,7 @@ func (s *DefaultServer) SetFirewall(fw Firewall) { // Stop stops the server func (s *DefaultServer) Stop() { - s.probeMu.Lock() - if s.probeCancel != nil { - s.probeCancel() - } s.ctxCancel() - s.probeMu.Unlock() - s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -411,6 +497,13 @@ func (s *DefaultServer) Stop() { } clear(s.extraDomains) + + // Clear health projection state so a subsequent Start doesn't + // inherit sticky flags (notably everHealthy) that would bypass + // the grace window during the next peer handshake. + s.healthProjectMu.Lock() + s.nsGroupProj = nil + s.healthProjectMu.Unlock() } func (s *DefaultServer) disableDNS() (retErr error) { @@ -424,10 +517,9 @@ func (s *DefaultServer) disableDNS() (retErr error) { return nil } - // Deregister original nameservers if they were registered as fallback - if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { - log.Debugf("deregistering original nameservers as fallback handlers") - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.fallbackHandler != nil { + log.Debugf("deregistering fallback handlers") + s.clearFallback() } if err := s.hostManager.restoreHostDNS(); err != nil { @@ -441,27 +533,16 @@ func (s *DefaultServer) disableDNS() (retErr error) { return nil } -// OnUpdatedHostDNSServer update the DNS servers addresses for root zones -// It will be applied if the mgm server do not enforce DNS settings for root zone +// OnUpdatedHostDNSServer updates the fallback DNS upstreams. Called by Android +// outside the engine's sync mux when the OS reports a network change, so it +// takes s.mux to serialize against host manager swaps in Initialize/enableDNS. 
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) { s.hostsDNSHolder.set(hostsDnsList) - - // Check if there's any root handler - var hasRootHandler bool - for _, handler := range s.dnsMuxMap { - if handler.domain == nbdns.RootZone { - hasRootHandler = true - break - } - } - - if hasRootHandler { - log.Debugf("on new host DNS config but skip to apply it") - return - } - log.Debugf("update host DNS settings: %+v", hostsDnsList) - s.addHostRootZone() + + s.mux.Lock() + defer s.mux.Unlock() + s.registerFallback() } // UpdateDNSServer processes an update received from the management service @@ -520,69 +601,6 @@ func (s *DefaultServer) SearchDomains() []string { return searchDomains } -// ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds. -// If a previous probe is still running, it will be cancelled before starting a new one. -func (s *DefaultServer) ProbeAvailability() { - if val := os.Getenv(envSkipDNSProbe); val != "" { - skipProbe, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err) - } - if skipProbe { - log.Infof("skipping DNS probe due to %s", envSkipDNSProbe) - return - } - } - - s.probeMu.Lock() - - // don't start probes on a stopped server - if s.ctx.Err() != nil { - s.probeMu.Unlock() - return - } - - // cancel any running probe - if s.probeCancel != nil { - s.probeCancel() - s.probeCancel = nil - } - - // wait for the previous probe goroutines to finish while holding - // the mutex so no other caller can start a new probe concurrently - s.probeWg.Wait() - - // start a new probe - probeCtx, probeCancel := context.WithCancel(s.ctx) - s.probeCancel = probeCancel - - s.probeWg.Add(1) - defer s.probeWg.Done() - - // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. 
- s.mux.Lock() - handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) - for _, mux := range s.dnsMuxMap { - handlers = append(handlers, mux.handler) - } - s.mux.Unlock() - - var wg sync.WaitGroup - for _, handler := range handlers { - wg.Add(1) - go func(h handlerWithStop) { - defer wg.Done() - h.ProbeAvailability(probeCtx) - }(handler) - } - - s.probeMu.Unlock() - - wg.Wait() - probeCancel() -} - func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { s.mux.Lock() defer s.mux.Unlock() @@ -746,19 +764,17 @@ func (s *DefaultServer) applyHostConfig() { s.currentConfigHash = hash } - s.registerFallback(config) + s.registerFallback() } // registerFallback registers original nameservers as low-priority fallback handlers. -func (s *DefaultServer) registerFallback(config HostDNSConfig) { - hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) - if !ok { - return - } - - originalNameservers := hostMgrWithNS.getOriginalNameservers() +// Replaces and Stop()s the previously-registered fallback handler so its +// context is released rather than leaked until GC. 
+func (s *DefaultServer) registerFallback() { + originalNameservers := s.hostManager.getOriginalNameservers() if len(originalNameservers) == 0 { - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler") + s.clearFallback() return } @@ -775,21 +791,28 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { log.Errorf("failed to create upstream resolver for original nameservers: %v", err) return } - handler.routeMatch = s.routeMatch + handler.selectedRoutes = s.selectedRoutes + var servers []netip.AddrPort for _, ns := range originalNameservers { - if ns == config.ServerIP { - log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) - continue - } - - addrPort := netip.AddrPortFrom(ns, DefaultPort) - handler.upstreamServers = append(handler.upstreamServers, addrPort) + servers = append(servers, netip.AddrPortFrom(ns, DefaultPort)) } - handler.deactivate = func(error) { /* always active */ } - handler.reactivate = func() { /* always active */ } + handler.addRace(servers) + prev := s.fallbackHandler + s.fallbackHandler = handler s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) + if prev != nil { + prev.Stop() + } +} + +func (s *DefaultServer) clearFallback() { + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.fallbackHandler != nil { + s.fallbackHandler.Stop() + s.fallbackHandler = nil + } } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) { @@ -847,100 +870,99 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityUpstream + priority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { - basePriority = PriorityDefault + priority = 
PriorityDefault } - updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority) + update, err := s.buildMergedDomainHandler(domainGroup, priority) if err != nil { + if errors.Is(err, errNoUsableNameservers) { + log.Errorf("no usable nameservers for domain=%s", domainGroup.domain) + continue + } return nil, err } - muxUpdates = append(muxUpdates, updates...) + muxUpdates = append(muxUpdates, *update) } return muxUpdates, nil } -func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) { - var muxUpdates []handlerWrapper +// buildMergedDomainHandler merges every nameserver group that targets the +// same domain into one handler whose inner groups are raced in parallel. +func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, priority int) (*handlerWrapper, error) { + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface, + s.statusRecorder, + s.hostsDNSHolder, + domain.Domain(domainGroup.domain), + ) + if err != nil { + return nil, fmt.Errorf("create upstream resolver: %v", err) + } + handler.selectedRoutes = s.selectedRoutes - for i, nsGroup := range domainGroup.groups { - // Decrement priority by handler index (0, 1, 2, ...) 
to avoid conflicts - priority := basePriority - i - - // Check if we're about to overlap with the next priority tier - if s.leaksPriority(domainGroup, basePriority, priority) { - break - } - - log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - domainGroup.domain, - ) - if err != nil { - return nil, fmt.Errorf("create upstream resolver: %v", err) - } - handler.routeMatch = s.routeMatch - - for _, ns := range nsGroup.NameServers { - if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", - ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) - continue - } - - if ns.IP == s.service.RuntimeIP() { - log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) - continue - } - - handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) - } - - if len(handler.upstreamServers) == 0 { - handler.Stop() - log.Errorf("received a nameserver group with an invalid nameserver list") + for _, nsGroup := range domainGroup.groups { + servers := s.filterNameServers(nsGroup.NameServers) + if len(servers) == 0 { + log.Warnf("nameserver group for domain=%s yielded no usable servers, skipping", domainGroup.domain) continue } - - // when upstream fails to resolve domain several times over all it servers - // it will calls this hook to exclude self from the configuration and - // reapply DNS settings, but it not touch the original configuration and serial number - // because it is temporal deactivation until next try - // - // after some period defined by upstream it tries to reactivate self by calling this hook - // everything we need here is just to re-apply current configuration because it already - // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, 
handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority) - - muxUpdates = append(muxUpdates, handlerWrapper{ - domain: domainGroup.domain, - handler: handler, - priority: priority, - }) + handler.addRace(servers) } - return muxUpdates, nil + if len(handler.upstreamServers) == 0 { + handler.Stop() + return nil, errNoUsableNameservers + } + + log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority) + + return &handlerWrapper{ + domain: domainGroup.domain, + handler: handler, + priority: priority, + }, nil } -func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { - if basePriority == PriorityUpstream && priority <= PriorityDefault { - log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityUpstream-PriorityDefault) - return true - } - if basePriority == PriorityDefault && priority <= PriorityFallback { - log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityDefault-PriorityFallback) - return true +func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + if ns.IP == s.service.RuntimeIP() { + log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) + continue + } + out = append(out, ns.AddrPort()) } + return out +} - return false +// usableNameServers returns the subset of nameServers the handler would +// actually query. 
Matches filterNameServers without the warning logs, so +// it's safe to call on every health-projection tick. +func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var runtimeIP netip.Addr + if s.service != nil { + runtimeIP = s.service.RuntimeIP() + } + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + continue + } + if runtimeIP.IsValid() && ns.IP == runtimeIP { + continue + } + out = append(out, ns.AddrPort()) + } + return out } func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { @@ -951,175 +973,356 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { } muxUpdateMap := make(registeredHandlerMap) - var containsRootUpdate bool for _, update := range muxUpdates { - if update.domain == nbdns.RootZone { - containsRootUpdate = true - } s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.handler.ID()] = update } - // If there's no root update and we had a root handler, restore it - if !containsRootUpdate { - for _, existing := range s.dnsMuxMap { - if existing.domain == nbdns.RootZone { - s.addHostRootZone() - break - } - } - } - s.dnsMuxMap = muxUpdateMap } -// upstreamCallbacks returns two functions, the first one is used to deactivate -// the upstream resolver from the configuration, the second one is used to -// reactivate it. Not allowed to call reactivate before deactivate. 
-func (s *DefaultServer) upstreamCallbacks( - nsGroup *nbdns.NameServerGroup, - handler dns.Handler, - priority int, -) (deactivate func(error), reactivate func()) { - var removeIndex map[string]int - deactivate = func(err error) { - s.mux.Lock() - defer s.mux.Unlock() - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("Temporarily deactivating nameservers group due to timeout") - - removeIndex = make(map[string]int) - for _, domain := range nsGroup.Domains { - removeIndex[domain] = -1 - } - if nsGroup.Primary { - removeIndex[nbdns.RootZone] = -1 - s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, priority) - } - - for i, item := range s.currentConfig.Domains { - if _, found := removeIndex[item.Domain]; found { - s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, priority) - removeIndex[item.Domain] = i - } - } - - // Always apply host config when nameserver goes down, regardless of batch mode - s.applyHostConfig() - - go func() { - if err := s.stateManager.PersistState(s.ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } - }() - - if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { - s.addHostRootZone() - } - - s.updateNSState(nsGroup, err, false) - } - - reactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { - continue - } - s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, priority) - } - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") - - if nsGroup.Primary { - s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, priority) - } - - // Always apply host config when nameserver reactivates, regardless of batch mode - s.applyHostConfig() - - 
s.updateNSState(nsGroup, nil, true) - } - return -} - -func (s *DefaultServer) addHostRootZone() { - hostDNSServers := s.hostsDNSHolder.get() - if len(hostDNSServers) == 0 { - log.Debug("no host DNS servers available, skipping root zone handler creation") - return - } - - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - nbdns.RootZone, - ) - if err != nil { - log.Errorf("unable to create a new upstream resolver, error: %v", err) - return - } - handler.routeMatch = s.routeMatch - - handler.upstreamServers = maps.Keys(hostDNSServers) - handler.deactivate = func(error) {} - handler.reactivate = func() {} - - s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) -} - +// updateNSGroupStates records the new group set and pokes the refresher. +// Must hold s.mux; projection runs async (see refreshHealth for why). func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { - var states []peer.NSGroupState + s.nsGroups = groups + select { + case s.healthRefresh <- struct{}{}: + default: + } +} - for _, group := range groups { - var servers []netip.AddrPort - for _, ns := range group.NameServers { - servers = append(servers, ns.AddrPort()) +// refreshHealth runs one projection cycle. Must not be called while +// holding s.mux: the route callbacks re-enter routemanager's lock. +func (s *DefaultServer) refreshHealth() { + s.mux.Lock() + groups := s.nsGroups + merged := s.collectUpstreamHealth() + selFn := s.selectedRoutes + actFn := s.activeRoutes + s.mux.Unlock() + + var selected, active route.HAMap + if selFn != nil { + selected = selFn() + } + if actFn != nil { + active = actFn() + } + + s.projectNSGroupHealth(nsHealthSnapshot{ + groups: groups, + merged: merged, + selected: selected, + active: active, + }) +} + +// projectNSGroupHealth applies the emission rules to the snapshot and +// publishes the resulting NSGroupStates. Serialized by healthProjectMu, +// lock-free wrt s.mux. 
+// +// Rules: +// - Healthy: emit recovery iff warningActive; set everHealthy. +// - Unhealthy: stamp unhealthySince on streak start; emit warning +// iff any of immediate / everHealthy / elapsed >= effective delay. +// - Undecided: no-op. +// +// "Immediate" means the group has at least one upstream that's public +// or overlay+Connected: no peer-startup race to wait out. +func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) { + if s.statusRecorder == nil { + return + } + + s.healthProjectMu.Lock() + defer s.healthProjectMu.Unlock() + + if s.nsGroupProj == nil { + s.nsGroupProj = make(map[nsGroupID]*nsGroupProj) + } + + now := time.Now() + delay := s.warningDelay(haMapRouteCount(snap.selected)) + states := make([]peer.NSGroupState, 0, len(snap.groups)) + seen := make(map[nsGroupID]struct{}, len(snap.groups)) + for _, group := range snap.groups { + servers := s.usableNameServers(group.NameServers) + if len(servers) == 0 { + continue + } + verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now) + id := generateGroupKey(group) + seen[id] = struct{}{} + + immediate := s.groupHasImmediateUpstream(servers, snap) + + p, known := s.nsGroupProj[id] + if !known { + p = &nsGroupProj{} + s.nsGroupProj[id] = p } - state := peer.NSGroupState{ - ID: generateGroupKey(group), + enabled := true + switch verdict { + case nsVerdictHealthy: + enabled = s.projectHealthy(p, servers) + case nsVerdictUnhealthy: + enabled = s.projectUnhealthy(p, servers, immediate, now, delay) + case nsVerdictUndecided: + // Stay Available until evidence says otherwise, unless a + // warning is already active for this group. Also clear any + // prior Unhealthy streak so a later Unhealthy verdict starts + // a fresh grace window rather than inheriting a stale one. 
+ p.unhealthySince = time.Time{} + enabled = !p.warningActive + groupErr = nil + } + + states = append(states, peer.NSGroupState{ + ID: string(id), Servers: servers, Domains: group.Domains, - // The probe will determine the state, default enabled - Enabled: true, - Error: nil, - } - states = append(states, state) + Enabled: enabled, + Error: groupErr, + }) } - s.statusRecorder.UpdateDNSStates(states) -} - -func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) { - states := s.statusRecorder.GetDNSStates() - id := generateGroupKey(nsGroup) - for i, state := range states { - if state.ID == id { - states[i].Enabled = enabled - states[i].Error = err - break + for id := range s.nsGroupProj { + if _, ok := seen[id]; !ok { + delete(s.nsGroupProj, id) } } s.statusRecorder.UpdateDNSStates(states) } -func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { - var servers []string +// projectHealthy records a healthy tick on p and publishes a recovery +// event iff a warning was active for the current streak. Returns the +// Enabled flag to record in NSGroupState. +func (s *DefaultServer) projectHealthy(p *nsGroupProj, servers []netip.AddrPort) bool { + p.everHealthy = true + p.unhealthySince = time.Time{} + if !p.warningActive { + return true + } + log.Debugf("DNS health: group [%s] recovered, emitting event", joinAddrPorts(servers)) + s.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_DNS, + "Nameserver group recovered", + "DNS servers are reachable again.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = false + return true +} + +// projectUnhealthy records an unhealthy tick on p, publishes the +// warning when the emission rules fire, and returns the Enabled flag +// to record in NSGroupState. 
+func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPort, immediate bool, now time.Time, delay time.Duration) bool { + streakStart := p.unhealthySince.IsZero() + if streakStart { + p.unhealthySince = now + } + reason := unhealthyEmitReason(immediate, p.everHealthy, now.Sub(p.unhealthySince), delay) + switch { + case reason != "" && !p.warningActive: + log.Debugf("DNS health: group [%s] unreachable, emitting event (reason=%s)", joinAddrPorts(servers), reason) + s.statusRecorder.PublishEvent( + proto.SystemEvent_WARNING, + proto.SystemEvent_DNS, + "Nameserver group unreachable", + "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = true + case streakStart && reason == "": + // One line per streak, not per tick. + log.Debugf("DNS health: group [%s] unreachable but holding warning for up to %v (overlay-routed, no connected peer)", joinAddrPorts(servers), delay) + } + return false +} + +// warningDelay returns the grace window for the given selected-route +// count. Scales gently: +1s per 100 routes, capped by +// warningDelayBonusCap. Parallel handshakes mean handshake time grows +// much slower than route count, so linear scaling would overcorrect. +// +// TODO: revisit the scaling curve with real-world data — the current +// values are a reasonable starting point, not a measured fit. +func (s *DefaultServer) warningDelay(routeCount int) time.Duration { + bonus := time.Duration(routeCount/100) * time.Second + if bonus > warningDelayBonusCap { + bonus = warningDelayBonusCap + } + return s.warningDelayBase + bonus +} + +// groupHasImmediateUpstream reports whether the group has at least one +// upstream in a classification that bypasses the grace window: public +// (outside the overlay range and not routed), or overlay/routed with a +// Connected peer. 
+// +// TODO(ipv6): include the v6 overlay prefix once it's plumbed in. +func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap nsHealthSnapshot) bool { + var overlayV4 netip.Prefix + if s.wgInterface != nil { + overlayV4 = s.wgInterface.Address().Network + } + for _, srv := range servers { + addr := srv.Addr().Unmap() + overlay := overlayV4.IsValid() && overlayV4.Contains(addr) + selMatched, selDynamic := haMapContains(snap.selected, addr) + // Treat an unknown (dynamic selected route) as possibly routed: + // the upstream might reach through a dynamic route whose Network + // hasn't resolved yet, and classifying as public would bypass + // the startup grace window. + routed := selMatched || selDynamic + if !overlay && !routed { + return true + } + if actMatched, _ := haMapContains(snap.active, addr); actMatched { + return true + } + } + return false +} + +// collectUpstreamHealth merges health snapshots across handlers, keeping +// the most recent success and failure per upstream when an address appears +// in more than one handler. 
+func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth { + merged := make(map[netip.AddrPort]UpstreamHealth) + for _, entry := range s.dnsMuxMap { + reporter, ok := entry.handler.(upstreamHealthReporter) + if !ok { + continue + } + for addr, h := range reporter.UpstreamHealth() { + existing, have := merged[addr] + if !have { + merged[addr] = h + continue + } + if h.LastOk.After(existing.LastOk) { + existing.LastOk = h.LastOk + } + if h.LastFail.After(existing.LastFail) { + existing.LastFail = h.LastFail + existing.LastErr = h.LastErr + } + merged[addr] = existing + } + } + return merged +} + +func (s *DefaultServer) startHealthRefresher() { + s.shutdownWg.Add(1) + go func() { + defer s.shutdownWg.Done() + ticker := time.NewTicker(nsGroupHealthRefreshInterval) + defer ticker.Stop() + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + case <-s.healthRefresh: + } + s.refreshHealth() + } + }() +} + +// evaluateNSGroupHealth decides a group's verdict from query records +// alone. Per upstream, the most-recent-in-lookback observation wins. +// Group is Healthy if any upstream is fresh-working, Unhealthy if any +// is fresh-broken with no fresh-working sibling, Undecided otherwise. 
+func evaluateNSGroupHealth(merged map[netip.AddrPort]UpstreamHealth, servers []netip.AddrPort, now time.Time) (nsGroupVerdict, error) { + anyWorking := false + anyBroken := false + var mostRecentFail time.Time + var mostRecentErr string + + for _, srv := range servers { + h, ok := merged[srv] + if !ok { + continue + } + switch classifyUpstreamHealth(h, now) { + case upstreamFresh: + anyWorking = true + case upstreamBroken: + anyBroken = true + if h.LastFail.After(mostRecentFail) { + mostRecentFail = h.LastFail + mostRecentErr = h.LastErr + } + } + } + + if anyWorking { + return nsVerdictHealthy, nil + } + if anyBroken { + if mostRecentErr == "" { + return nsVerdictUnhealthy, nil + } + return nsVerdictUnhealthy, errors.New(mostRecentErr) + } + return nsVerdictUndecided, nil +} + +// upstreamClassification is the per-upstream verdict within healthLookback. +type upstreamClassification int + +const ( + upstreamStale upstreamClassification = iota + upstreamFresh + upstreamBroken +) + +// classifyUpstreamHealth compares the last ok and last fail timestamps +// against healthLookback and returns which one (if any) counts. Fresh +// wins when both are in-window and ok is newer; broken otherwise. 
+func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassification { + okRecent := !h.LastOk.IsZero() && now.Sub(h.LastOk) <= healthLookback + failRecent := !h.LastFail.IsZero() && now.Sub(h.LastFail) <= healthLookback + switch { + case okRecent && failRecent: + if h.LastOk.After(h.LastFail) { + return upstreamFresh + } + return upstreamBroken + case okRecent: + return upstreamFresh + case failRecent: + return upstreamBroken + } + return upstreamStale +} + +func joinAddrPorts(servers []netip.AddrPort) string { + parts := make([]string, 0, len(servers)) + for _, s := range servers { + parts = append(parts, s.String()) + } + return strings.Join(parts, ", ") +} + +// generateGroupKey returns a stable identity for an NS group so health +// state (everHealthy / warningActive) survives reorderings in the +// configured nameserver or domain lists. +func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID { + servers := make([]string, 0, len(nsGroup.NameServers)) for _, ns := range nsGroup.NameServers { servers = append(servers, ns.AddrPort().String()) } - return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) + slices.Sort(servers) + domains := slices.Clone(nsGroup.Domains) + slices.Sort(domains) + return nsGroupID(fmt.Sprintf("%v_%v", servers, domains)) } // groupNSGroupsByDomain groups nameserver groups by their match domains @@ -1161,6 +1364,21 @@ func toZone(d domain.Domain) domain.Domain { ) } +// unhealthyEmitReason returns the tag of the rule that fires the +// warning now, or "" if the group is still inside its grace window. 
+func unhealthyEmitReason(immediate, everHealthy bool, elapsed, delay time.Duration) string { + switch { + case immediate: + return "immediate" + case everHealthy: + return "ever-healthy" + case elapsed >= delay: + return "grace-elapsed" + default: + return "" + } +} + // PopulateManagementDomain populates the DNS cache with management domain func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { if s.mgmtCacheResolver != nil { diff --git a/client/internal/dns/server_android.go b/client/internal/dns/server_android.go index 7ca12d69d..b2cb26f65 100644 --- a/client/internal/dns/server_android.go +++ b/client/internal/dns/server_android.go @@ -1,5 +1,5 @@ package dns func (s *DefaultServer) initialize() (manager hostManager, err error) { - return newHostManager() + return newHostManager(s.hostsDNSHolder) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 1026a29fc..722c2abd7 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -6,7 +6,7 @@ import ( "net" "net/netip" "os" - "strings" + "runtime" "testing" "time" @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -31,8 +32,10 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -101,16 +104,17 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } -func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { +func 
generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase { var srvs []netip.AddrPort for _, srv := range servers { srvs = append(srvs, srv.AddrPort()) } - return &upstreamResolverBase{ - domain: domain, - upstreamServers: srvs, - cancel: func() {}, + u := &upstreamResolverBase{ + domain: domain.Domain(d), + cancel: func() {}, } + u.addRace(srvs) + return u } func TestUpdateDNSServer(t *testing.T) { @@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) { } } -func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { - hostManager := &mockHostConfigurator{} - server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: local.NewResolver(), - handlerChain: NewHandlerChain(), - hostManager: hostManager, - currentConfig: HostDNSConfig{ - Domains: []DomainConfig{ - {false, "domain0", false}, - {false, "domain1", false}, - {false, "domain2", false}, - }, - }, - statusRecorder: peer.NewRecorder("mgm"), - } - - var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { - domains := []string{} - for _, item := range config.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - domainsUpdate = strings.Join(domains, ",") - return nil - } - - deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ - Domains: []string{"domain1"}, - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, - }, - }, nil, 0) - - deactivate(nil) - expected := "domain0,domain2" - domains := []string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got := strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, got) - } - - reactivate() - expected = "domain0,domain1,domain2" - domains = 
[]string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got = strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) - } -} - func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { } func TestDNSPermanent_updateUpstream(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } func TestDNSPermanent_matchOnly(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } } +// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path, +// which only matches a real production setup on android (NewDefaultServerPermanentUpstream +// + androidHostManager). On non-android the desktop host manager replaces it +// during Initialize and the assertion stops making sense. Skipped here until we +// have an android CI runner. 
+func skipUnlessAndroid(t *testing.T) { + t.Helper() + if runtime.GOOS != "android" { + t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS") + } +} + func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { t.Helper() ov := os.Getenv("NB_WG_KERNEL_DISABLED") @@ -1065,7 +1017,6 @@ type mockHandler struct { func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) Stop() {} -func (m *mockHandler) ProbeAvailability(context.Context) {} func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} @@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } +// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple +// admin-defined nameserver groups targeting the same domain collapse into a +// single handler with each group preserved as a sequential inner list. +func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) { + wgInterface := &mocWGIface{} + service := NewServiceViaMemory(wgInterface) + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgInterface, + service: service, + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: &noopHostConfigurator{}, + dnsMuxMap: make(registeredHandlerMap), + } + + groups := []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + } + + muxUpdates, err := server.buildUpstreamHandlerUpdate(groups) + require.NoError(t, err) + require.Len(t, muxUpdates, 1, 
"same-domain groups should merge into one handler") + assert.Equal(t, "example.com", muxUpdates[0].domain) + assert.Equal(t, PriorityUpstream, muxUpdates[0].priority) + + handler := muxUpdates[0].handler.(*upstreamResolver) + require.Len(t, handler.upstreamServers, 2, "handler should have two groups") + assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0]) + assert.Equal(t, upstreamRace{ + netip.MustParseAddrPort("192.0.2.2:53"), + netip.MustParseAddrPort("192.0.2.3:53"), + }, handler.upstreamServers[1]) +} + +// TestEvaluateNSGroupHealth covers the records-only verdict. The gate +// (overlay route selected-but-no-active-peer) is intentionally NOT an +// input to the evaluator anymore: the verdict drives the Enabled flag, +// which must always reflect what we actually observed. Gate-aware event +// suppression is tested separately in the projection test. +// +// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail, +// stale Ok, Ok newer than Fail, Fail newer than Ok}. +// Group verdict: any fresh-working → Healthy; any fresh-broken with no +// fresh-working → Unhealthy; otherwise Undecided. 
+func TestEvaluateNSGroupHealth(t *testing.T) { + now := time.Now() + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)} + recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"} + staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)} + staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"} + okThenFail := UpstreamHealth{ + LastOk: now.Add(-10 * time.Second), + LastFail: now.Add(-1 * time.Second), + LastErr: "timeout", + } + failThenOk := UpstreamHealth{ + LastOk: now.Add(-1 * time.Second), + LastFail: now.Add(-10 * time.Second), + LastErr: "timeout", + } + + tests := []struct { + name string + health map[netip.AddrPort]UpstreamHealth + servers []netip.AddrPort + wantVerdict nsGroupVerdict + wantErrSubst string + }{ + { + name: "no record, undecided", + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "fresh success, healthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "fresh failure, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "only stale success, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "only stale failure, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "both fresh, fail newer, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: okThenFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "both fresh, ok newer, healthy", + health: 
map[netip.AddrPort]UpstreamHealth{a: failThenOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one success wins", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + b: recentOk, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one fail one unseen, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "two upstreams, all recent failures, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"}, + b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"}, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "SERVFAIL", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now) + assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch") + if tc.wantErrSubst != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSubst) + } else { + assert.NoError(t, err) + } + }) + } +} + +// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed +// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates +// without spinning up real handlers. +type healthStubHandler struct { + health map[netip.AddrPort]UpstreamHealth +} + +func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} +func (h *healthStubHandler) Stop() {} +func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" } +func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + return h.health +} + +// TestProjection_SteadyStateIsSilent guards against duplicate events: +// while a group stays Unhealthy tick after tick, only the first +// Unhealthy transition may emit. 
Same for staying Healthy. +func TestProjection_SteadyStateIsSilent(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "first fail emits warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.tick() + fx.expectNoEvent("staying unhealthy must not re-emit") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery on transition") + + fx.tick() + fx.tick() + fx.expectNoEvent("staying healthy must not re-emit") +} + +// projTestFixture is the common setup for the projection tests: a +// single-upstream group whose route classification the test can flip by +// assigning to selected/active. Callers drive failures/successes by +// mutating stub.health and calling refreshHealth. +type projTestFixture struct { + t *testing.T + recorder *peer.Status + events <-chan *proto.SystemEvent + server *DefaultServer + stub *healthStubHandler + group *nbdns.NameServerGroup + srv netip.AddrPort + selected route.HAMap + active route.HAMap +} + +func newProjTestFixture(t *testing.T) *projTestFixture { + t.Helper() + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + srv := netip.MustParseAddrPort("100.64.0.1:53") + fx := &projTestFixture{ + t: t, + recorder: recorder, + events: sub.Events(), + stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}}, + srv: srv, + group: &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + }, + } + fx.server = &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return fx.selected }, + activeRoutes: func() 
route.HAMap { return fx.active }, + warningDelayBase: defaultWarningDelayBase, + } + fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream} + + fx.server.mux.Lock() + fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group}) + fx.server.mux.Unlock() + return fx +} + +func (f *projTestFixture) setHealth(h UpstreamHealth) { + f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h} +} + +func (f *projTestFixture) tick() []peer.NSGroupState { + f.server.refreshHealth() + return f.recorder.GetDNSStates() +} + +func (f *projTestFixture) expectNoEvent(why string) { + f.t.Helper() + select { + case evt := <-f.events: + f.t.Fatalf("unexpected event (%s): %+v", why, evt) + case <-time.After(100 * time.Millisecond): + } +} + +func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent { + f.t.Helper() + select { + case evt := <-f.events: + assert.Contains(f.t, evt.Message, substr, why) + return evt + case <-time.After(time.Second): + f.t.Fatalf("expected event (%s) with %q", why, substr) + return nil + } +} + +var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16") +var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}} + +// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream +// that is not inside any selected route (public DNS) fires the warning +// on the first Unhealthy tick, no grace period. +func TestProjection_PublicFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "public DNS failure") +} + +// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2: +// the upstream is inside a selected route AND the route has a Connected +// peer. Tunnel is up, failure is real, emit immediately. 
+func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "overlay + connected failure") +} + +// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the +// upstream is routed but no peer is Connected (Connecting/Idle/missing). +// First tick: Unhealthy display, no warning. After the grace window +// elapses with no recovery, the warning fires. +func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) { + grace := 50 * time.Millisecond + fx := newProjTestFixture(t) + fx.server.warningDelayBase = grace + fx.selected = overlayMapForTest + // active stays nil: routed but not connected. + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled, "display must reflect failure even during grace window") + fx.expectNoEvent("first fail tick within grace window") + + time.Sleep(grace + 10*time.Millisecond) + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "warning after grace window") +} + +// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream +// whose address is inside the WireGuard overlay range but is not +// covered by any selected route (peer-to-peer DNS without an explicit +// route). Until a peer reports Connected for that address, startup +// failures must be held just like the routed case. 
+func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + overlayPeer := netip.MustParseAddrPort("100.66.100.5:53") + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: 50 * time.Millisecond, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{ + overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}, + }} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-sub.Events(): + t.Fatalf("unexpected event during grace window: %+v", evt) + case <-time.After(100 * time.Millisecond): + } + + time.Sleep(60 * time.Millisecond) + stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}} + server.refreshHealth() + + select { + case evt := <-sub.Events(): + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected warning after grace window") + } +} + +// TestProjection_StopClearsHealthState verifies that Stop wipes the +// per-group projection state so a subsequent Start doesn't inherit +// sticky flags (notably everHealthy) that would bypass the grace +// window during the next peer handshake. 
+func TestProjection_StopClearsHealthState(t *testing.T) { + wgIface := &mocWGIface{} + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgIface, + service: NewServiceViaMemory(wgIface), + hostManager: &noopHostConfigurator{}, + extraDomains: map[domain.Domain]int{}, + dnsMuxMap: make(registeredHandlerMap), + statusRecorder: peer.NewRecorder("mgm"), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: defaultWarningDelayBase, + currentConfigHash: ^uint64(0), + } + server.ctx, server.ctxCancel = context.WithCancel(context.Background()) + + srv := netip.MustParseAddrPort("8.8.8.8:53") + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + server.healthProjectMu.Lock() + p, ok := server.nsGroupProj[generateGroupKey(group)] + server.healthProjectMu.Unlock() + require.True(t, ok, "projection state should exist after tick") + require.True(t, p.everHealthy, "tick with success must set everHealthy") + + server.Stop() + + server.healthProjectMu.Lock() + cleared := server.nsGroupProj == nil + server.healthProjectMu.Unlock() + assert.True(t, cleared, "Stop must clear nsGroupProj") +} + +// TestProjection_OverlayRecoversDuringGrace covers the happy path of +// rule 3: startup failures while the peer is handshaking, then the peer +// comes up and a query succeeds before the grace window elapses. No +// warning should ever have fired, and no recovery either. 
+func TestProjection_OverlayRecoversDuringGrace(t *testing.T) { + fx := newProjTestFixture(t) + fx.server.warningDelayBase = 200 * time.Millisecond + fx.selected = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectNoEvent("fail within grace, warning suppressed") + + fx.active = overlayMapForTest + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("recovery without prior warning must not emit") +} + +// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the +// whole design leans on: recovery events only appear when a warning +// event was actually emitted for the current streak. A Healthy verdict +// without a prior warning is silent, so the user never sees "recovered" +// out of thin air. +func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("first healthy tick should not recover anything") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "public fail emits immediately") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery follows real warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "second cycle warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "second cycle recovery") +} + +// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group +// has ever been Healthy, subsequent failures skip the grace window even +// if classification says "routed + not connected". The system has +// proved it can work, so any new failure is real. 
+func TestProjection_EverHealthyOverridesDelay(t *testing.T) { + fx := newProjTestFixture(t) + // Large base so any emission must come from the everHealthy bypass, not elapsed time. + fx.server.warningDelayBase = time.Hour + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + // Establish "ever healthy". + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectNoEvent("first healthy tick") + + // Peer drops. Query fails. Routed + not connected → normally grace, + // but everHealthy flag bypasses it. + fx.active = nil + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "failure after ever-healthy must be immediate") +} + +// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff +// from the design discussion: once a group has been healthy, a brief +// reconnect that produces a failing tick will fire warning + recovery. +// This is by design: user-visible blips are accurate signal, not noise. +func TestProjection_ReconnectBlipEmitsPair(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "blip warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "blip recovery") +} + +// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream +// rule: a group with at least one public upstream is in the "immediate" +// category regardless of the other upstreams' routing, because the +// public one has no peer-startup excuse. Prevents public-DNS failures +// from being hidden behind a routed sibling. 
+func TestProjection_MixedGroupEmitsImmediately(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + events := sub.Events() + + public := netip.MustParseAddrPort("8.8.8.8:53") + overlay := netip.MustParseAddrPort("100.64.0.1:53") + overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}} + + server := &DefaultServer{ + ctx: context.Background(), + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return overlayMap }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: time.Hour, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{ + {IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())}, + {IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())}, + }, + } + stub := &healthStubHandler{ + health: map[netip.AddrPort]UpstreamHealth{ + public: {LastFail: time.Now(), LastErr: "servfail"}, + overlay: {LastFail: time.Now(), LastErr: "timeout"}, + }, + } + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-events: + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected immediate warning because group contains a public upstream") + } +} + func TestDNSLoopPrevention(t *testing.T) { wgInterface := &mocWGIface{} service := NewServiceViaMemory(wgInterface) @@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) { if tt.expectedHandlers > 0 { handler := muxUpdates[0].handler.(*upstreamResolver) - assert.Len(t, handler.upstreamServers, len(tt.expectedServers)) + flat := handler.flatUpstreams() + 
assert.Len(t, flat, len(tt.expectedServers)) if tt.shouldFilterOwnIP { - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { assert.NotEqual(t, dnsServerIP, upstream.Addr()) } } for _, expected := range tt.expectedServers { found := false - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { if upstream.Addr() == expected { found = true break diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 573dff540..bd301e177 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "slices" "time" "github.com/godbus/dbus/v5" @@ -40,10 +41,17 @@ const ( ) type systemdDbusConfigurator struct { - dbusLinkObject dbus.ObjectPath - ifaceName string + dbusLinkObject dbus.ObjectPath + ifaceName string + wgIndex int + origNameservers []netip.Addr } +const ( + systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS" + systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute" +) + // the types below are based on dbus specification, each field is mapped to a dbus type // see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types // see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types @@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index) - return &systemdDbusConfigurator{ + c := &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), ifaceName: wgInterface, - }, nil + wgIndex: iface.Index, + } + + origNameservers, err := c.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from systemd-resolved: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original 
nameservers captured from systemd-resolved default-route links; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers) + } + c.origNameservers = origNameservers + return c, nil +} + +// captureOriginalNameservers reads per-link DNS from systemd-resolved for +// every default-route link except our own WG link. Non-default-route links +// (VPNs, docker bridges) are skipped because their upstreams wouldn't +// actually serve host queries. +func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("list interfaces: %w", err) + } + + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + for _, iface := range ifaces { + if !s.isCandidateLink(iface) { + continue + } + linkPath, err := getSystemdLinkPath(iface.Index) + if err != nil || !isSystemdLinkDefaultRoute(linkPath) { + continue + } + for _, addr := range readSystemdLinkDNS(linkPath) { + addr = normalizeSystemdAddr(addr, iface.Name) + if !addr.IsValid() { + continue + } + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nil +} + +func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool { + if iface.Index == s.wgIndex { + return false + } + if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 { + return false + } + return true +} + +// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches +// the link's iface name as zone for link-local v6 (Link.DNS strips it). +// Returns the zero Addr to signal "skip this entry". 
+func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr { + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + return netip.Addr{} + } + if addr.IsLinkLocalUnicast() { + return addr.WithZone(ifaceName) + } + return addr +} + +func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) { + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return "", fmt.Errorf("dbus resolve1: %w", err) + } + defer closeConn() + var p string + if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil { + return "", err + } + return dbus.ObjectPath(p), nil +} + +func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool { + obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath) + if err != nil { + return false + } + defer closeConn() + v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty) + if err != nil { + return false + } + b, ok := v.Value().(bool) + return ok && b +} + +func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath) + if err != nil { + return nil + } + defer closeConn() + v, err := obj.GetProperty(systemdDbusLinkDNSProperty) + if err != nil { + return nil + } + entries, ok := v.Value().([][]any) + if !ok { + return nil + } + var out []netip.Addr + for _, entry := range entries { + if len(entry) < 2 { + continue + } + raw, ok := entry[1].([]byte) + if !ok { + continue + } + addr, ok := netip.AddrFromSlice(raw) + if !ok { + continue + } + out = append(out, addr) + } + return out +} + +func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(s.origNameservers) } func (s *systemdDbusConfigurator) supportCustomPort() bool { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 39064f26c..a4f713d68 100644 --- a/client/internal/dns/upstream.go +++ 
b/client/internal/dns/upstream.go @@ -1,3 +1,32 @@ +// Package dns implements the client-side DNS stack: listener/service on the +// peer's tunnel address, handler chain that routes questions by domain and +// priority, and upstream resolvers that forward what remains to configured +// nameservers. +// +// # Upstream resolution and the race model +// +// When two or more nameserver groups target the same domain, DefaultServer +// merges them into one upstream handler whose state is: +// +// upstreamResolverBase +// └── upstreamServers []upstreamRace // one entry per source NS group +// └── []netip.AddrPort // primary, fallback, ... +// +// Each source nameserver group contributes one upstreamRace. Within a race +// upstreams are tried in order: the next is used only on failure (timeout, +// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops +// the walk. When more than one race exists, ServeDNS fans out one +// goroutine per race and returns the first valid answer, cancelling the +// rest. A handler with a single race skips the fan-out. +// +// # Health projection +// +// Query outcomes are recorded per-upstream in UpstreamHealth. The server +// periodically merges these snapshots across handlers and projects them +// into peer.NSGroupState. There is no active probing: a group is marked +// unhealthy only when every seen upstream has a recent failure and none +// has a recent success. Becoming unhealthy emits at most one +// SystemEvent_WARNING per failure streak; steady-state refreshes never repeat it.
package dns import ( @@ -11,11 +40,8 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" - "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" @@ -25,7 +51,8 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) var currentMTU uint16 = iface.DefaultMTU @@ -67,15 +94,17 @@ const ( // Set longer than UpstreamTimeout to ensure context timeout takes precedence ClientTimeout = 5 * time.Second - reactivatePeriod = 30 * time.Second - probeTimeout = 2 * time.Second - // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP // payload from the tunnel MTU. ipUDPHeaderSize = 60 + 8 -) -const testRecord = "com." + // raceMaxTotalTimeout caps the combined time spent walking all upstreams + // within one race, so a slow primary can't eat the whole race budget. + raceMaxTotalTimeout = 5 * time.Second + // raceMinPerUpstreamTimeout is the floor applied when dividing + // raceMaxTotalTimeout across upstreams within a race. + raceMinPerUpstreamTimeout = 2 * time.Second +) const ( protoUDP = "udp" @@ -84,6 +113,69 @@ const ( type dnsProtocolKey struct{} +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. 
+type upstreamProtocolResult struct { + protocol string +} + +type upstreamClient interface { + exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +type UpstreamResolver interface { + serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) + upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +// upstreamRace is an ordered list of upstreams derived from one configured +// nameserver group. Order matters: the first upstream is tried first, the +// second only on failure, and so on. Multiple upstreamRace values coexist +// inside one resolver when overlapping nameserver groups target the same +// domain; those races run in parallel and the first valid answer wins. +type upstreamRace []netip.AddrPort + +// UpstreamHealth is the last query-path outcome for a single upstream, +// consumed by nameserver-group status projection. +type UpstreamHealth struct { + LastOk time.Time + LastFail time.Time + LastErr string +} + +type upstreamResolverBase struct { + ctx context.Context + cancel context.CancelFunc + upstreamClient upstreamClient + upstreamServers []upstreamRace + domain domain.Domain + upstreamTimeout time.Duration + + healthMu sync.RWMutex + health map[netip.AddrPort]*UpstreamHealth + + statusRecorder *peer.Status + // selectedRoutes returns the current set of client routes the admin + // has enabled. Called lazily from the query hot path when an upstream + // might need a tunnel-bound client (iOS) and from health projection. + selectedRoutes func() route.HAMap +} + +type upstreamFailure struct { + upstream netip.AddrPort + reason string +} + +type raceResult struct { + msg *dns.Msg + upstream netip.AddrPort + protocol string + ede string + failures []upstreamFailure +} + // contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. 
func contextWithDNSProtocol(ctx context.Context, network string) context.Context { return context.WithValue(ctx, dnsProtocolKey{}, network) @@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string { return "" } -type upstreamProtocolKey struct{} - -// upstreamProtocolResult holds the protocol used for the upstream exchange. -// Stored as a pointer in context so the exchange function can set it. -type upstreamProtocolResult struct { - protocol string -} - -// contextWithupstreamProtocolResult stores a mutable result holder in the context. -func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { +// contextWithUpstreamProtocolResult stores a mutable result holder in the context. +func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { r := &upstreamProtocolResult{} return context.WithValue(ctx, upstreamProtocolKey{}, r), r } @@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) { } } -type upstreamClient interface { - exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type UpstreamResolver interface { - serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) - upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type upstreamResolverBase struct { - ctx context.Context - cancel context.CancelFunc - upstreamClient upstreamClient - upstreamServers []netip.AddrPort - domain string - disabled bool - successCount atomic.Int32 - mutex sync.Mutex - reactivatePeriod time.Duration - upstreamTimeout time.Duration - wg sync.WaitGroup - - deactivate func(error) - reactivate func() - statusRecorder *peer.Status - routeMatch func(netip.Addr) bool -} - -type upstreamFailure struct { - upstream netip.AddrPort - reason string -} - -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { +func 
newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ - ctx: ctx, - cancel: cancel, - domain: domain, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - statusRecorder: statusRecorder, + ctx: ctx, + cancel: cancel, + domain: d, + upstreamTimeout: UpstreamTimeout, + statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("Upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.flatUpstreams()) } -// ID returns the unique handler ID +// ID returns the unique handler ID. Race groupings and within-race +// ordering are both part of the identity: [[A,B]] and [[A],[B]] query +// the same servers but with different semantics (serial fallback vs +// parallel race), so their handlers must not collide. func (u *upstreamResolverBase) ID() types.HandlerID { - servers := slices.Clone(u.upstreamServers) - slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) }) - hash := sha256.New() - hash.Write([]byte(u.domain + ":")) - for _, s := range servers { - hash.Write([]byte(s.String())) - hash.Write([]byte("|")) + hash.Write([]byte(u.domain.PunycodeString() + ":")) + for _, race := range u.upstreamServers { + hash.Write([]byte("[")) + for _, s := range race { + hash.Write([]byte(s.String())) + hash.Write([]byte("|")) + } + hash.Write([]byte("]")) } return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } @@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { } func (u *upstreamResolverBase) Stop() { - log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) + log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams()) u.cancel() +} - u.mutex.Lock() - u.wg.Wait() - u.mutex.Unlock() +// flatUpstreams is for 
logging and ID hashing only, not for dispatch. +func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort { + var out []netip.AddrPort + for _, g := range u.upstreamServers { + out = append(out, g...) + } + return out +} +// setSelectedRoutes swaps the accessor used to classify overlay-routed +// upstreams. Called when route sources are wired after the handler was +// built (permanent / iOS constructors). +func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) { + u.selectedRoutes = selected +} + +func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) { + if len(servers) == 0 { + return + } + u.upstreamServers = append(u.upstreamServers, slices.Clone(servers)) } // ServeDNS handles a DNS request @@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { - timeout := u.upstreamTimeout - if len(u.upstreamServers) > 1 { - maxTotal := 5 * time.Second - minPerUpstream := 2 * time.Second - scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) - if scaledTimeout > minPerUpstream { - timeout = scaledTimeout - } else { - timeout = minPerUpstream - } + groups := u.upstreamServers + switch len(groups) { + case 0: + return false, nil + case 1: + return u.tryOnlyRace(ctx, w, r, groups[0], logger) + default: + return u.raceAll(ctx, w, r, groups, logger) + } +} + +func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + res := u.tryRace(ctx, r, group) + if res.msg == nil { + return false, res.failures + } + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, res.failures +} + +// raceAll runs one worker per group in parallel, taking 
the first valid +// answer and cancelling the rest. +func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + raceCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Buffer sized to len(groups) so workers never block on send, even + // after the coordinator has returned. + results := make(chan raceResult, len(groups)) + for _, g := range groups { + // tryRace clones the request per attempt, so workers never share + // a *dns.Msg and concurrent EDNS0 mutations can't race. + go func(g upstreamRace) { + results <- u.tryRace(raceCtx, r, g) + }(g) } var failures []upstreamFailure - for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { - failures = append(failures, *failure) - } else { - return true, failures + for range groups { + select { + case res := <-results: + failures = append(failures, res.failures...) + if res.msg != nil { + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, failures + } + case <-ctx.Done(): + return false, failures } } return false, failures } -// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. 
-func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { - var rm *dns.Msg - var t time.Duration - var err error +func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult { + timeout := u.upstreamTimeout + if len(group) > 1 { + // Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts + // still honor raceMinPerUpstreamTimeout as a floor for correctness + // on slow links, but the outer context ensures the combined walk + // cannot exceed the cap regardless of group size. + timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout) + defer cancel() + } + + var failures []upstreamFailure + for _, upstream := range group { + if ctx.Err() != nil { + return raceResult{failures: failures} + } + // Clone the request per attempt: the exchange path mutates EDNS0 + // options in-place, so reusing the same *dns.Msg across sequential + // upstreams would carry those mutations (e.g. a reduced UDP size) + // into the next attempt. + res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout) + if failure != nil { + failures = append(failures, *failure) + continue + } + res.failures = failures + return res + } + return raceResult{failures: failures} +} + +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) { + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() + ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx) // Advertise EDNS0 so the upstream may include Extended DNS Errors // (RFC 8914) in failure responses; we use those to short-circuit // failover for definitive answers like DNSSEC validation failures. 
- // Operate on a copy so the inbound request is unchanged: a client that - // did not advertise EDNS0 must not see an OPT in the response. + // The caller already passed a per-attempt copy, so we can mutate r + // directly; hadEdns reflects the original client request's state and + // controls whether we strip the OPT from the response. hadEdns := r.IsEdns0() != nil - reqUp := r if !hadEdns { - reqUp = r.Copy() - reqUp.SetEdns0(upstreamUDPSize(), false) + r.SetEdns0(upstreamUDPSize(), false) } - var startTime time.Time - var upstreamProto *upstreamProtocolResult - func() { - ctx, cancel := context.WithTimeout(parentCtx, timeout) - defer cancel() - ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) - startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp) - }() + startTime := time.Now() + rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r) if err != nil { - return u.handleUpstreamError(err, upstream, startTime) + // A parent cancellation (e.g., another race won and the coordinator + // cancelled the losers) is not an upstream failure. Check both the + // error chain and the parent context: a transport may surface the + // cancellation as a read/deadline error rather than context.Canceled. 
+ if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) { + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"} + } + failure := u.handleUpstreamError(err, upstream, startTime) + u.markUpstreamFail(upstream, failure.reason) + return raceResult{}, failure } if rm == nil || !rm.Response { - return &upstreamFailure{upstream: upstream, reason: "no response"} + u.markUpstreamFail(upstream, "no response") + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"} + } + + proto := "" + if upstreamProto != nil { + proto = upstreamProto.protocol } if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if code, ok := nonRetryableEDE(rm); ok { - resutil.SetMeta(w, "ede", edeName(code)) if !hadEdns { stripOPT(rm) } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil } - return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} + reason := dns.RcodeToString[rm.Rcode] + u.markUpstreamFail(upstream, reason) + return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason} } if !hadEdns { stripOPT(rm) } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil +} + +// healthEntry returns the mutable health record for addr, lazily creating +// the map and the entry. Caller must hold u.healthMu. 
+func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth { + if u.health == nil { + u.health = make(map[netip.AddrPort]*UpstreamHealth) + } + h := u.health[addr] + if h == nil { + h = &UpstreamHealth{} + u.health[addr] = h + } + return h +} + +func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastOk = time.Now() + h.LastFail = time.Time{} + h.LastErr = "" +} + +func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastFail = time.Now() + h.LastErr = reason +} + +// UpstreamHealth returns a snapshot of per-upstream query outcomes. +func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + u.healthMu.RLock() + defer u.healthMu.RUnlock() + out := make(map[netip.AddrPort]UpstreamHealth, len(u.health)) + for k, v := range u.health { + out[k] = *v + } + return out } // upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams, @@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { - u.successCount.Add(1) +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) 
{ resutil.SetMeta(w, "upstream", upstream.String()) - if upstreamProto != nil && upstreamProto.protocol != "" { - resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + if proto != "" { + resutil.SetMeta(w, "upstream_protocol", proto) } // Clear Zero bit from external responses to prevent upstream servers from @@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn if err := w.WriteMsg(rm); err != nil { logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) - return true } - - return true } func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) { - totalUpstreams := len(u.upstreamServers) + totalUpstreams := len(u.flatUpstreams()) failedCount := len(failures) failureSummary := formatFailures(failures) @@ -434,119 +633,6 @@ func edeName(code uint16) string { return fmt.Sprintf("EDE %d", code) } -// ProbeAvailability tests all upstream servers simultaneously and -// disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { - u.mutex.Lock() - defer u.mutex.Unlock() - - // avoid probe if upstreams could resolve at least one query - if u.successCount.Load() > 0 { - return - } - - var success bool - var mu sync.Mutex - var wg sync.WaitGroup - - var errs *multierror.Error - for _, upstream := range u.upstreamServers { - wg.Add(1) - go func(upstream netip.AddrPort) { - defer wg.Done() - err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) - if err != nil { - mu.Lock() - errs = multierror.Append(errs, err) - mu.Unlock() - log.Warnf("probing upstream nameserver %s: %s", upstream, err) - return - } - - mu.Lock() - success = true - mu.Unlock() - }(upstream) - } - - wg.Wait() - - select { - case <-ctx.Done(): - return - case <-u.ctx.Done(): - return - default: - } - - // didn't find a working upstream server, let's disable and try later - if !success { - 
u.disable(errs.ErrorOrNil()) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (probe failed)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - ) - } -} - -// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response -func (u *upstreamResolverBase) waitUntilResponse() { - exponentialBackOff := &backoff.ExponentialBackOff{ - InitialInterval: 500 * time.Millisecond, - RandomizationFactor: 0.5, - Multiplier: 1.1, - MaxInterval: u.reactivatePeriod, - MaxElapsedTime: 0, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - } - - operation := func() error { - select { - case <-u.ctx.Done(): - return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString())) - default: - } - - for _, upstream := range u.upstreamServers { - if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { - log.Tracef("upstream check for %s: %s", upstream, err) - } else { - // at least one upstream server is available, stop probing - return nil - } - } - - log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff()) - return fmt.Errorf("upstream check call error") - } - - err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx)) - if err != nil { - if errors.Is(err, context.Canceled) { - log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString()) - } else { - log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err) - } - return - } - - log.Infof("upstreams %s are responsive again. 
Adding them back to system", u.upstreamServersString()) - u.successCount.Add(1) - u.reactivate() - u.mutex.Lock() - u.disabled = false - u.mutex.Unlock() -} - // isTimeout returns true if the given error is a network timeout error. // // Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout @@ -558,45 +644,6 @@ func isTimeout(err error) bool { return false } -func (u *upstreamResolverBase) disable(err error) { - if u.disabled { - return - } - - log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) - u.successCount.Store(0) - u.deactivate(err) - u.disabled = true - u.wg.Add(1) - go func() { - defer u.wg.Done() - u.waitUntilResponse() - }() -} - -func (u *upstreamResolverBase) upstreamServersString() string { - var servers []string - for _, server := range u.upstreamServers { - servers = append(servers, server.String()) - } - return strings.Join(servers, ", ") -} - -func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { - mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) - defer cancel() - - if externalCtx != nil { - stop2 := context.AfterFunc(externalCtx, cancel) - defer stop2() - } - - r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - - _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) - return err -} - // clientUDPMaxSize returns the maximum UDP response size the client accepts. func clientUDPMaxSize(r *dns.Msg) int { if opt := r.IsEdns0(); opt != nil { @@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int { // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. // If the inbound request came over TCP (via context), it skips the UDP attempt. -// If the passed context is nil, this will use Exchange instead of ExchangeContext. 
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { // If the request came in over TCP, go straight to TCP upstream. if dnsProtocolFromContext(ctx) == protoTCP { - tcpClient := *client - tcpClient.Net = protoTCP - rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream) + rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } @@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u opt.SetUDPSize(maxUDPPayload) } - var ( - rm *dns.Msg - t time.Duration - err error - ) - - if ctx == nil { - rm, t, err = client.Exchange(r, upstream) - } else { - rm, t, err = client.ExchangeContext(ctx, r, upstream) - } - + rm, t, err := client.ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with udp: %w", err) } @@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u // data than the client's buffer, we could truncate locally and skip // the TCP retry. - tcpClient := *client - tcpClient.Net = protoTCP - - if ctx == nil { - rm, t, err = tcpClient.Exchange(r, upstream) - } else { - rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream) - } - + rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } @@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } +// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a +// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on +// the tunnel interface), it is converted to the equivalent *net.TCPAddr +// so net.Dialer doesn't reject the TCP dial with "mismatched local +// address type". 
+func toTCPClient(c *dns.Client) *dns.Client { + tcp := *c + tcp.Net = protoTCP + if tcp.Dialer == nil { + return &tcp + } + d := *tcp.Dialer + if ua, ok := d.LocalAddr.(*net.UDPAddr); ok { + d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone} + } + tcp.Dialer = &d + return &tcp +} + // ExchangeWithNetstack performs a DNS exchange using netstack for dialing. // This is needed when netstack is enabled to reach peer IPs through the tunnel. func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { @@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { return bestMatch } -func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { - if u.statusRecorder == nil { - return "" +// haMapRouteCount returns the total number of routes across all HA +// groups in the map. route.HAMap is keyed by HAUniqueID with slices of +// routes per key, so len(hm) is the number of HA groups, not routes. +func haMapRouteCount(hm route.HAMap) int { + total := 0 + for _, routes := range hm { + total += len(routes) } - - peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) - if peerInfo == nil { - return "" - } - - return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) + return total +} + +// haMapContains checks whether ip is covered by any concrete prefix in +// the HA map. haveDynamic is reported separately: dynamic (domain-based) +// routes carry a placeholder Network that can't be prefix-checked, so we +// can't know at this point whether ip is reached through one. Callers +// decide how to interpret the unknown: health projection treats it as +// "possibly routed" to avoid emitting false-positive warnings during +// startup, while iOS dial selection requires a concrete match before +// binding to the tunnel. 
+func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) { + for _, routes := range hm { + for _, r := range routes { + if r.IsDynamic() { + haveDynamic = true + continue + } + if r.Network.Contains(ip) { + return true, haveDynamic + } + } + } + return false, haveDynamic } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 988adb7d2..f7ab48b10 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -26,9 +27,9 @@ func newUpstreamResolver( _ WGIface, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) c := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, hostsDNSHolder: hostsDNSHolder, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 910c3779e..dc841757b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -12,6 +12,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -24,9 +25,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) nonIOS := &upstreamResolver{ upstreamResolverBase: 
upstreamResolverBase, nsNet: wgIface.GetNet(), diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 0e04742a0..b989bf0f9 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolverIOS struct { @@ -27,9 +28,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, @@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * upstreamIP = upstreamIP.Unmap() } addr := u.wgIface.Address() + var routed bool + if u.selectedRoutes != nil { + // Only a concrete prefix match binds to the tunnel: dialing + // through a private client for an upstream we can't prove is + // routed would break public resolvers. 
+ routed, _ = haMapContains(u.selectedRoutes(), upstreamIP) + } needsPrivate := addr.Network.Contains(upstreamIP) || addr.IPv6Net.Contains(upstreamIP) || - (u.routeMatch != nil && u.routeMatch(upstreamIP)) + routed if needsPrivate { log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) @@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } } - // Cannot use client.ExchangeContext because it overwrites our Dialer - return ExchangeWithFallback(nil, client, r, upstream) + return ExchangeWithFallback(ctx, client, r, upstream) } // GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index d6aec05ca..8b3c589f1 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync/atomic" "testing" "time" @@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())) } } - resolver.upstreamServers = servers + resolver.addRace(servers) resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { cancel() @@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) { return "", nil } -type mockUpstreamResolver struct { - r *dns.Msg - rtt time.Duration - err error -} - -// exchange mock implementation of exchange from upstreamResolver -func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { - return c.r, c.rtt, c.err -} - type mockUpstreamResponse struct { - msg *dns.Msg - err error + msg *dns.Msg + err error + delay time.Duration } type mockUpstreamResolverPerServer struct { @@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer 
struct { rtt time.Duration } -func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { - if r, ok := c.responses[upstream]; ok { - return r.msg, c.rtt, r.err +func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { + r, ok := c.responses[upstream] + if !ok { + return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) } - return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) -} - -func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - mockClient := &mockUpstreamResolver{ - err: dns.ErrTime, - r: new(dns.Msg), - rtt: time.Millisecond, - } - - resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: mockClient, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: time.Microsecond * 100, - } - addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection - resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - - failed := false - resolver.deactivate = func(error) { - failed = true - // After deactivation, make the mock client work again - mockClient.err = nil - } - - reactivated := false - resolver.reactivate = func() { - reactivated = true - } - - resolver.ProbeAvailability(context.TODO()) - - if !failed { - t.Errorf("expected that resolving was deactivated") - return - } - - if !resolver.disabled { - t.Errorf("resolver should be Disabled") - return - } - - time.Sleep(time.Millisecond * 200) - - if !reactivated { - t.Errorf("expected that resolving was reactivated") - return - } - - if resolver.disabled { - t.Errorf("should be enabled") + if r.delay > 0 { + select { + case <-time.After(r.delay): + case <-ctx.Done(): + return nil, c.rtt, ctx.Err() + } } + return r.msg, c.rtt, r.err } func TestUpstreamResolver_Failover(t *testing.T) { @@ -339,9 +286,9 @@ func 
TestUpstreamResolver_Failover(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: trackingClient, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream1, upstream2}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: mockClient, - upstreamServers: []netip.AddrPort{upstream}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL") } +// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups +// configured for the same domain, with one broken group. The merge+race +// path should answer as fast as the working group and not pay the timeout +// of the broken one on every query. +func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) { + broken := netip.MustParseAddrPort("192.0.2.1:53") + working := netip.MustParseAddrPort("192.0.2.2:53") + successAnswer := "192.0.2.100" + timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")} + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + // Force the broken upstream to only unblock via timeout / + // cancellation so the assertion below can't pass if races + // were run serially. 
+ broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond}, + working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: 250 * time.Millisecond, + } + resolver.addRace([]netip.AddrPort{broken}) + resolver.addRace([]netip.AddrPort{working}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + start := time.Now() + resolver.ServeDNS(responseWriter, inputMSG) + elapsed := time.Since(start) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode) + require.NotEmpty(t, responseMSG.Answer) + assert.Contains(t, responseMSG.Answer[0].String(), successAnswer) + // Working group answers in a single RTT; the broken group's + // timeout (250ms) must not block the response. + assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout") +} + +// TestUpstreamResolver_AllGroupsFail checks that when every group fails the +// resolver returns SERVFAIL rather than leaking a partial response. 
+func TestUpstreamResolver_AllGroupsFail(t *testing.T) { + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{a}) + resolver.addRace([]netip.AddrPort{b}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + require.NotNil(t, responseMSG) + assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode) +} + +// TestUpstreamResolver_HealthTracking verifies that query-path results are +// recorded into per-upstream health, which is what projects back to +// NSGroupState for status reporting. 
+func TestUpstreamResolver_HealthTracking(t *testing.T) { + ok := netip.MustParseAddrPort("192.0.2.10:53") + bad := netip.MustParseAddrPort("192.0.2.11:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")}, + bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{ok, bad}) + + responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }} + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + + health := resolver.UpstreamHealth() + require.Contains(t, health, ok) + assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set") + assert.Empty(t, health[ok].LastErr) + + // bad upstream was never tried because ok answered first; its health + // should remain unset. + assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers") +} + func TestFormatFailures(t *testing.T) { testCases := []struct { name string @@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { // Verify that a client EDNS0 larger than our MTU-derived limit gets // capped in the outgoing request so the upstream doesn't send a // response larger than our read buffer. 
- var receivedUDPSize uint16 + var receivedUDPSize atomic.Uint32 udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { if opt := r.IsEdns0(); opt != nil { - receivedUDPSize = opt.UDPSize() + receivedUDPSize.Store(uint32(opt.UDPSize())) } m := new(dns.Msg) m.SetReply(r) @@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { require.NotNil(t, rm) expectedMax := uint16(currentMTU - ipUDPHeaderSize) - assert.Equal(t, expectedMax, receivedUDPSize, + assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()), "upstream should see capped EDNS0, not the client's 4096") } @@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: tracking, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, + upstreamServers: []upstreamRace{{upstream1, upstream2}}, upstreamTimeout: UpstreamTimeout, } diff --git a/client/internal/engine.go b/client/internal/engine.go index 66fe6056b..3bd0d4621 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) - e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { - for _, routes := range e.routeManager.GetSelectedClientRoutes() { - for _, r := range routes { - if r.Network.Contains(ip) { - return true - } - } - } - return false - }) + e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes) if err = e.wgInterfaceCreate(); err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) @@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.networkSerial = serial - // Test received (upstream) servers for availability right away instead of upon usage. 
- // If no server of a server group responds this will disable the respective handler and retry later. - go e.dnsServer.ProbeAvailability() return nil } @@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) return dnsServer, nil default: diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 907f1f592..839ec14c0 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -53,6 +53,7 @@ type Manager interface { GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap GetSelectedClientRoutes() route.HAMap + GetActiveClientRoutes() route.HAMap GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap { return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) } +// GetActiveClientRoutes returns the subset of selected client routes +// that are currently reachable: the route's peer is Connected and is +// the one actively carrying the route (not just an HA sibling). 
+func (m *DefaultManager) GetActiveClientRoutes() route.HAMap { + m.mux.Lock() + selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) + recorder := m.statusRecorder + m.mux.Unlock() + + if recorder == nil { + return selected + } + + out := make(route.HAMap, len(selected)) + for id, routes := range selected { + for _, r := range routes { + st, err := recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if st.ConnStatus != peer.StatusConnected { + continue + } + if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute { + continue + } + out[id] = routes + break + } + } + return out +} + // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { m.mux.Lock() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 66b5e30dd..937314995 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -19,6 +19,7 @@ type MockManager struct { GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap GetSelectedClientRoutesFunc func() route.HAMap + GetActiveClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route StopFunc func(manager *statemanager.Manager) } @@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap { return nil } +// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface +func (m *MockManager) GetActiveClientRoutes() route.HAMap { + if m.GetActiveClientRoutesFunc != nil { + return m.GetActiveClientRoutesFunc() + } + return nil +} + // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { if m.GetClientRoutesWithNetIDFunc != nil { diff --git 
a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 33f5ab1b0..bafbb0031 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - hostDNS := []netip.AddrPort{ - netip.MustParseAddrPort("9.9.9.9:53"), - netip.MustParseAddrPort("149.112.112.112:53"), - } - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources From e916f12cca508dfea584e7b72cf99a135acebc2b Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 15 May 2026 19:13:44 +0200 Subject: [PATCH 13/17] [proxy] auth token generation on mapping (#6157) * [management / proxy] auth token generation on mapping * fix tests --- management/internals/shared/grpc/proxy.go | 15 +++--- .../shared/grpc/proxy_snapshot_test.go | 53 +++++++++++++++++++ .../shared/grpc/validate_session_test.go | 14 +++-- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 9e5027547..eada2d86a 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -394,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if end > len(mappings) { end = len(mappings) } + for _, m := range mappings[i:end] { + token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL()) + if err != nil { + return fmt.Errorf("generate auth token for service %s: %w", m.Id, err) + } + m.AuthToken = token + } if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ Mapping: mappings[i:end], InitialSyncComplete: end == 
len(mappings), @@ -425,18 +432,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return nil, fmt.Errorf("get services from store: %w", err) } + oidcCfg := s.GetOIDCValidationConfig() var mappings []*proto.ProxyMapping for _, service := range services { if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) - if err != nil { - return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) - } - - m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + m := service.ToProtoMapping(rpservice.Create, "", oidcCfg) if !proxyAcceptsMapping(conn, m) { continue } diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go index e0c7425c5..68d2ecfd1 100644 --- a/management/internals/shared/grpc/proxy_snapshot_test.go +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) { assert.Empty(t, stream.messages[0].Mapping) assert.True(t, stream.messages[0].InitialSyncComplete) } + +type hookingStream struct { + grpc.ServerStream + onSend func(*proto.GetMappingUpdateResponse) +} + +func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error { + if s.onSend != nil { + s.onSend(m) + } + return nil +} + +func (s *hookingStream) Context() context.Context { return context.Background() } +func (s *hookingStream) SetHeader(metadata.MD) error { return nil } +func (s *hookingStream) SendHeader(metadata.MD) error { return nil } +func (s *hookingStream) SetTrailer(metadata.MD) {} +func (s *hookingStream) SendMsg(any) error { return nil } +func (s *hookingStream) RecvMsg(any) error { return nil } 
+ +func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 6 + const ttl = 100 * time.Millisecond + const sendDelay = 200 * time.Millisecond + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + s.tokenTTL = ttl + + var validateErrs []error + stream := &hookingStream{ + onSend: func(resp *proto.GetMappingUpdateResponse) { + for _, m := range resp.Mapping { + if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil { + validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err)) + } + } + time.Sleep(sendDelay) + }, + } + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Empty(t, validateErrs, + "tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness") +} diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 6cd95f988..7b7ffcfb2 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -326,17 +326,25 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, return nil, nil } +func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) 
(*proxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error { +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error { return nil } -func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error { return nil } From 22e2519d7113dffec718198e54474cc0a6d71c87 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 22:51:48 +0900 Subject: [PATCH 14/17] [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) --- management/server/account.go | 37 +++++++++++-- management/server/account_test.go | 90 +++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 77a46a069..e7b4acaac 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -291,10 +291,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.NewPermissionDeniedError() } + // Canonicalize the incoming range so a caller-supplied prefix with host bits + // (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net. 
+ newSettings.NetworkRange = newSettings.NetworkRange.Masked() + var oldSettings *types.Settings var updateAccountPeers bool var groupChangesAffectPeers bool var reloadReverseProxy bool + var effectiveOldNetworkRange netip.Prefix err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var groupsUpdated bool @@ -308,6 +313,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } + // No lock: the transaction already holds Settings(Update), and network.Net is + // only mutated by reallocateAccountPeerIPs, which is reachable only through + // this same code path. A Share lock here would extend an unnecessary row lock + // and complicate ordering against updatePeerIPv6InTransaction. + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get account network: %w", err) + } + effectiveOldNetworkRange = prefixFromIPNet(network.Net) + if oldSettings.Extra != nil && newSettings.Extra != nil && oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled { approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID) @@ -321,7 +336,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } - if oldSettings.NetworkRange != newSettings.NetworkRange { + if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { return err } @@ -396,9 +411,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta) } - if oldSettings.NetworkRange != newSettings.NetworkRange { + if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange { eventMeta := map[string]any{ - "old_network_range": 
oldSettings.NetworkRange.String(), + "old_network_range": effectiveOldNetworkRange.String(), "new_network_range": newSettings.NetworkRange.String(), } am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) @@ -443,6 +458,22 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool { return !slices.Equal(oldGroups, newGroups) } +// prefixFromIPNet returns the overlay prefix actually allocated on the account +// network, or an invalid prefix if none is set. Settings.NetworkRange is a +// user-facing override that is empty on legacy accounts, so the effective +// range must be read from network.Net to compare against an incoming update. +func prefixFromIPNet(ipNet net.IPNet) netip.Prefix { + if ipNet.IP == nil { + return netip.Prefix{} + } + addr, ok := netip.AddrFromSlice(ipNet.IP) + if !ok { + return netip.Prefix{} + } + ones, _ := ipNet.Mask.Size() + return netip.PrefixFrom(addr.Unmap(), ones) +} + func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { diff --git a/management/server/account_test.go b/management/server/account_test.go index 65b27df49..60720faa6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3970,6 +3970,96 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi } } +// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved guards against +// peer IP reallocation when a settings update carries the network range that is already +// in use. 
Legacy accounts have Settings.NetworkRange unset in the DB while network.Net +// holds the actual allocated overlay; the dashboard backfills the GET response from +// network.Net and echoes the value back on PUT, so the diff must be against the +// effective range to avoid renumbering every peer on an unrelated settings change. +func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + ctx := context.Background() + + settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset") + + network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated") + addr, ok := netip.AddrFromSlice(network.Net.IP) + require.True(t, ok) + ones, _ := network.Net.Mask.Size() + effective := netip.PrefixFrom(addr.Unmap(), ones) + require.True(t, effective.IsValid()) + + before := map[string]netip.Addr{peer1.ID: peer1.IP, peer2.ID: peer2.IP, peer3.ID: peer3.IP} + + // Round-trip the effective range as if the dashboard echoed back the GET-backfilled value. 
+ _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + NetworkRange: effective, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + require.Len(t, peers, len(before)) + for _, p := range peers { + assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID) + } + + // Carrying the same range with host bits set must also be a no-op once canonicalized. + hostBitsForm := netip.PrefixFrom(peer1.IP, ones) + require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking") + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + NetworkRange: hostBitsForm, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID) + } + + // Omitting NetworkRange (invalid prefix) must also be a no-op. + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID) + } + + // Sanity: an actually different range still triggers reallocation. 
+ newRange := netip.MustParsePrefix("100.99.0.0/16") + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + NetworkRange: newRange, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP) + assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID) + } +} + func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) { manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) ctx := context.Background() From 347c5bf317794729a044ce9f866f29e357d386d9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 16 May 2026 16:29:01 +0200 Subject: [PATCH 15/17] Avoid context cancellation in `cancelPeerRoutines` (#6175) When closing go routines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation. This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation. 
--- management/internals/shared/grpc/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 70024bac6..1d8234304 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even } func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { - unlock := s.acquirePeerLockByUID(ctx, peer.Key) + uncanceledCTX := context.WithoutCancel(ctx) + unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key) defer unlock() - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime) } func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { From 3f91f49277e1841bdfccda06ae7baa0430e6de2e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 23:52:57 +0900 Subject: [PATCH 16/17] Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) --- client/installer.nsis | 23 ++++++++++++++++++----- client/netbird.wxs | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/client/installer.nsis b/client/installer.nsis index 63bff1c5b..3e057df10 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" -; Create autostart registry entry based on checkbox +; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view +; or HKCU by legacy installers. +DetailPrint "Cleaning legacy 32-bit / HKCU entries..." 
+DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" +SetRegView 32 +DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +DeleteRegKey HKLM "${REG_APP_PATH}" +DeleteRegKey HKLM "${UI_REG_APP_PATH}" +DeleteRegKey HKLM "${UNINSTALL_PATH}" +SetRegView 64 + DetailPrint "Autostart enabled: $AutostartEnabled" ${If} $AutostartEnabled == "1" WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"' DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" ${Else} DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" - ; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. - DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DetailPrint "Autostart not enabled by user" ${EndIf} @@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' DetailPrint "Terminating Netbird UI process..." ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` -; Remove autostart registry entry +; Remove autostart entries from every view a previous installer may have used. DetailPrint "Removing autostart registry entry if exists..." DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" -; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" +SetRegView 32 +DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +DeleteRegKey HKLM "${REG_APP_PATH}" +DeleteRegKey HKLM "${UI_REG_APP_PATH}" +DeleteRegKey HKLM "${UNINSTALL_PATH}" +SetRegView 64 ; Handle data deletion based on checkbox DetailPrint "Checking if user requested data deletion..." 
diff --git a/client/netbird.wxs b/client/netbird.wxs index 6f18b63b5..96814ce52 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -64,6 +64,13 @@ + + + + + @@ -76,10 +83,28 @@ + + + + + + + + + + + From 7fae703a2741d2b0c7472fecb1c7f82ce5b28357 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Mon, 18 May 2026 10:25:18 +0200 Subject: [PATCH 17/17] [client/ui] Port IPv6 toggle and paired default-route filter to Wails UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Brings two main-side PRs' UI behavior across the Fyne→Wails rewrite: - #5631 (IPv6 overlay support): add "Enable IPv6" row to the polished SettingsNetwork tab; the legacy screens/Settings.tsx already had it, but modules/settings/SettingsNetwork.tsx (the user-visible Settings window) was missing the toggle. - #6150 (mirror v4 exit selection onto v6 pair): replace the literal "0.0.0.0/0" || "::/0" filter in screens/Networks.tsx with an isDefaultRoute() helper that handles the daemon's merged-range display string (e.g. "0.0.0.0/0, ::/0"), so paired v4/v6 exit nodes are classified correctly. 
--- .../frontend/src/modules/settings/SettingsNetwork.tsx | 6 ++++++ client/ui/frontend/src/screens/Networks.tsx | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/client/ui/frontend/src/modules/settings/SettingsNetwork.tsx b/client/ui/frontend/src/modules/settings/SettingsNetwork.tsx index 79bb09837..9eef37dc8 100644 --- a/client/ui/frontend/src/modules/settings/SettingsNetwork.tsx +++ b/client/ui/frontend/src/modules/settings/SettingsNetwork.tsx @@ -45,6 +45,12 @@ export function SettingsNetwork() { label={"Enable Server Routes"} helpText={"Advertise this host's local routes to other peers."} /> + setField("disableIpv6", !v)} + label={"Enable IPv6"} + helpText={"Use IPv6 addressing for the NetBird overlay network."} + /> ); diff --git a/client/ui/frontend/src/screens/Networks.tsx b/client/ui/frontend/src/screens/Networks.tsx index 5fa7a31cc..7e82d71a6 100644 --- a/client/ui/frontend/src/screens/Networks.tsx +++ b/client/ui/frontend/src/screens/Networks.tsx @@ -55,7 +55,7 @@ export default function Networks() { const overlapping = useMemo(() => filterOverlapping(routes), [routes]); const exitNodes = useMemo( - () => routes.filter((r) => r.range === "0.0.0.0/0" || r.range === "::/0"), + () => routes.filter((r) => isDefaultRoute(r.range)), [routes], ); @@ -146,6 +146,15 @@ function NetworkList({ ); } +// range is the merged display string from the daemon, e.g. "0.0.0.0/0", +// "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry. +function isDefaultRoute(range: string): boolean { + return range.split(",").some((part) => { + const trimmed = part.trim(); + return trimmed === "0.0.0.0/0" || trimmed === "::/0"; + }); +} + function filterOverlapping(routes: Network[]): Network[] { const byRange = new Map(); for (const r of routes) {