diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 8bbc98726..63bdc6dc0 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,6 +31,7 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) } type Manager struct { @@ -68,8 +69,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d var ret []*domain.Domain // Add connected proxy clusters as free domains. - // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // For BYOD accounts, only their own cluster is returned; otherwise shared clusters. + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) return nil, err @@ -112,8 +113,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName return nil, status.NewPermissionDeniedError() } - // Verify the target cluster is in the available clusters - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // Verify the target cluster is in the available clusters for this account + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -259,7 +260,7 @@ func (m Manager) GetClusterDomains() []string { // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster. func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -284,6 +285,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } +func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) { + byodAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get BYOD cluster addresses: %w", err) + } + if len(byodAddresses) > 0 { + return byodAddresses, nil + } + return m.proxyManager.GetActiveClusterAddresses(ctx) +} + func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) { for _, customDomain := range customDomains { if strings.HasSuffix(domain, "."+customDomain.Domain) { diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go new file mode 100644 index 000000000..cf317ca4c --- /dev/null +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -0,0 +1,144 @@ +package manager + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockProxyManager struct { + getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error) +} + +func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveClusterAddressesFunc != nil { + return m.getActiveClusterAddressesFunc(ctx) + } + return nil, nil +} + +func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveClusterAddressesForAccountFunc != nil { + return m.getActiveClusterAddressesForAccountFunc(ctx, accountID) + } + return nil, nil +} + +func TestGetClusterAllowList_BYODProxy(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return []string{"byod.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOD addresses exist") + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byod.example.com"}, result) +} + +func TestGetClusterAllowList_NoBYOD_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_BYODError_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, errors.New("db error") + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_BYODEmptySlice_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) +} + +func TestExtractClusterFromFreeDomain(t *testing.T) { + clusters := []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"} + + tests := []struct { + name string + domain string + wantCluster string + wantOK bool + }{ + { + name: "matches EU cluster", + domain: "myapp.abc123.eu.proxy.netbird.io", + wantCluster: "eu.proxy.netbird.io", + wantOK: true, + }, + { + name: "matches US cluster", + domain: "myapp.xyz789.us.proxy.netbird.io", + wantCluster: "us.proxy.netbird.io", + wantOK: true, + }, + { + name: "no match - custom domain", + domain: "app.example.com", + wantOK: false, + }, + { + name: "no match - partial cluster name", + domain: "proxy.netbird.io", + wantOK: false, + }, + { + name: "exact cluster name - no prefix", + domain: "eu.proxy.netbird.io", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cluster, ok := ExtractClusterFromFreeDomain(tt.domain, clusters) + assert.Equal(t, tt.wantOK, ok) + if tt.wantOK { + assert.Equal(t, tt.wantCluster, cluster) + } + }) + } +} diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 15f2f9f54..b37719d18 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,11 +11,16 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error Disconnect(ctx context.Context, proxyID string) error Heartbeat(ctx context.Context, proxyID string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error + GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) + CountAccountProxies(ctx context.Context, accountID string) (int64, error) + IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteProxy(ctx context.Context, proxyID string) error } // OIDCValidationConfig contains the OIDC configuration needed for token validation. diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index 4c0964b5c..8b5a56bac 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -13,9 +13,15 @@ import ( // store defines the interface for proxy persistence operations type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error + DisconnectProxy(ctx context.Context, proxyID string) error UpdateProxyHeartbeat(ctx context.Context, proxyID string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteProxy(ctx context.Context, proxyID string) error } // Manager handles all proxy operations @@ -38,15 +44,16 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { } // Connect registers a new proxy connection in the database -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error { now := time.Now() p := &proxy.Proxy{ ID: proxyID, ClusterAddress: clusterAddress, IPAddress: ipAddress, + AccountID: accountID, LastSeen: now, ConnectedAt: &now, - Status: "connected", + Status: proxy.StatusConnected, } if err := m.store.SaveProxy(ctx, p); err != nil { @@ -64,16 +71,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress } // Disconnect marks a proxy as disconnected in the database -func (m Manager) Disconnect(ctx context.Context, proxyID string) error { - now := time.Now() - p := &proxy.Proxy{ - ID: proxyID, - Status: "disconnected", - DisconnectedAt: &now, - LastSeen: now, - } - - if err := m.store.SaveProxy(ctx, p); err != nil { +func (m *Manager) Disconnect(ctx context.Context, proxyID string) error { + if err := m.store.DisconnectProxy(ctx, proxyID); err != nil { log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) return err } @@ -86,7 +85,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error { } // Heartbeat updates the proxy's last seen timestamp -func (m Manager) Heartbeat(ctx context.Context, proxyID string) error { +func (m *Manager) Heartbeat(ctx context.Context, proxyID string) error { if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil { log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) return err @@ -96,7 +95,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID string) error { } // GetActiveClusterAddresses returns all unique cluster addresses for active proxies -func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { +func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) @@ -106,10 +105,44 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error } // CleanupStale removes proxies that haven't sent heartbeat in the specified duration -func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { +func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) return err } return nil } + +func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err) + return nil, err + } + return addresses, nil +} + +func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) { + return m.store.GetProxyByAccountID(ctx, accountID) +} + +func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + return m.store.CountProxiesByAccountID(ctx, accountID) +} + +func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID) + if err != nil { + return false, err + } + return !conflicting, nil +} + +func (m *Manager) DeleteProxy(ctx context.Context, proxyID string) error { + if err := m.store.DeleteProxy(ctx, proxyID); err != nil { + log.WithContext(ctx).Errorf("failed to delete proxy %s: %v", proxyID, err) + return err + } + return nil +} + diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go new file mode 100644 index 000000000..3c3bf8fa4 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -0,0 +1,321 @@ +package manager + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +type mockStore struct { + saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error + disconnectProxyFunc func(ctx context.Context, proxyID string) error + updateProxyHeartbeatFunc func(ctx context.Context, proxyID string) error + getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) + cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error + getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) + countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) + isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) + deleteProxyFunc func(ctx context.Context, proxyID string) error +} + +func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { + if m.saveProxyFunc != nil { + return m.saveProxyFunc(ctx, p) + } + return nil +} +func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error { + if m.disconnectProxyFunc != nil { + return m.disconnectProxyFunc(ctx, proxyID) + } + return nil +} +func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error { + if m.updateProxyHeartbeatFunc != nil { + return m.updateProxyHeartbeatFunc(ctx, proxyID) + } + return nil +} +func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveProxyClusterAddressesFunc != nil { + return m.getActiveProxyClusterAddressesFunc(ctx) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveProxyClusterAddressesForAccFunc != nil { + return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID) + } + return nil, nil +} +func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error { + if m.cleanupStaleProxiesFunc != nil { + return m.cleanupStaleProxiesFunc(ctx, d) + } + return nil +} +func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + if m.getProxyByAccountIDFunc != nil { + return m.getProxyByAccountIDFunc(ctx, accountID) + } + return nil, nil +} +func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + if m.countProxiesByAccountIDFunc != nil { + return m.countProxiesByAccountIDFunc(ctx, accountID) + } + return 0, nil +} +func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + if m.isClusterAddressConflictingFunc != nil { + return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID) + } + return false, nil +} +func (m *mockStore) DeleteProxy(ctx context.Context, proxyID string) error { + if m.deleteProxyFunc != nil { + return m.deleteProxyFunc(ctx, proxyID) + } + return nil +} + +func newTestManager(s store) *Manager { + meter := noop.NewMeterProvider().Meter("test") + m, err := NewManager(s, meter) + if err != nil { + panic(err) + } + return m +} + +func TestConnect_WithAccountID(t *testing.T) { + accountID := "acc-123" + + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Equal(t, "proxy-1", savedProxy.ID) + assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress) + assert.Equal(t, "10.0.0.1", savedProxy.IPAddress) + assert.Equal(t, &accountID, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) + assert.NotNil(t, savedProxy.ConnectedAt) +} + +func TestConnect_WithoutAccountID(t *testing.T) { + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Nil(t, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) +} + +func TestConnect_StoreError(t *testing.T) { + s := &mockStore{ + saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil) + assert.Error(t, err) +} + +func TestIsClusterAddressAvailable(t *testing.T) { + tests := []struct { + name string + conflicting bool + storeErr error + wantResult bool + wantErr bool + }{ + { + name: "available - no conflict", + conflicting: false, + wantResult: true, + }, + { + name: "not available - conflict exists", + conflicting: true, + wantResult: false, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) { + return tt.conflicting, tt.storeErr + }, + } + + mgr := newTestManager(s) + result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +func TestCountAccountProxies(t *testing.T) { + tests := []struct { + name string + count int64 + storeErr error + wantCount int64 + wantErr bool + }{ + { + name: "no proxies", + count: 0, + wantCount: 0, + }, + { + name: "one proxy", + count: 1, + wantCount: 1, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) { + return tt.count, tt.storeErr + }, + } + + mgr := newTestManager(s) + count, err := mgr.CountAccountProxies(context.Background(), "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + }) + } +} + +func TestGetAccountProxy(t *testing.T) { + accountID := "acc-123" + + t.Run("found", func(t *testing.T) { + expected := &proxy.Proxy{ + ID: "proxy-1", + ClusterAddress: "byod.example.com", + AccountID: &accountID, + Status: proxy.StatusConnected, + } + s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) { + assert.Equal(t, accountID, accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + p, err := mgr.GetAccountProxy(context.Background(), accountID) + require.NoError(t, err) + assert.Equal(t, expected, p) + }) + + t.Run("not found", func(t *testing.T) { + s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, errors.New("not found") + }, + } + + mgr := newTestManager(s) + _, err := mgr.GetAccountProxy(context.Background(), accountID) + assert.Error(t, err) + }) +} + +func TestDeleteProxy(t *testing.T) { + t.Run("success", func(t *testing.T) { + var deletedID string + s := &mockStore{ + deleteProxyFunc: func(_ context.Context, proxyID string) error { + deletedID = proxyID + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteProxy(context.Background(), "proxy-1") + require.NoError(t, err) + assert.Equal(t, "proxy-1", deletedID) + }) + + t.Run("store error", func(t *testing.T) { + s := &mockStore{ + deleteProxyFunc: func(_ context.Context, _ string) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteProxy(context.Background(), "proxy-1") + assert.Error(t, err) + }) +} + +func TestGetActiveClusterAddressesForAccount(t *testing.T) { + expected := []string{"byod.example.com"} + s := &mockStore{ + getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, expected, result) +} diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index d9645ba88..bad30df04 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -51,17 +51,17 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac } // Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID) ret0, _ := ret[0].(error) return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID) } // Disconnect mocks base method. @@ -93,6 +93,21 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) } +// GetActiveClusterAddressesForAccount mocks base method. +func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusterAddressesForAccount indicates an expected call of GetActiveClusterAddressesForAccount. +func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID) +} + // Heartbeat mocks base method. func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error { m.ctrl.T.Helper() @@ -107,6 +122,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID) } +// GetAccountProxy mocks base method. +func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountProxy indicates an expected call of GetAccountProxy. +func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID) +} + +// CountAccountProxies mocks base method. +func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAccountProxies indicates an expected call of CountAccountProxies. +func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID) +} + +// IsClusterAddressAvailable mocks base method. +func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable. +func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID) +} + +// DeleteProxy mocks base method. +func (m *MockManager) DeleteProxy(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteProxy indicates an expected call of DeleteProxy. +func (mr *MockManagerMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockManager)(nil).DeleteProxy), ctx, proxyID) +} + // MockController is a mock of Controller interface. type MockController struct { ctrl *gomock.Controller diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 699e1ed02..96d1142e7 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -2,11 +2,17 @@ package proxy import "time" +const ( + StatusConnected = "connected" + StatusDisconnected = "disconnected" +) + // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` + AccountID *string `gorm:"type:varchar(255);uniqueIndex:idx_proxy_account_id_unique"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time diff --git a/management/internals/modules/reverseproxy/proxytoken/handler.go b/management/internals/modules/reverseproxy/proxytoken/handler.go new file mode 100644 index 000000000..591c465cb --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler.go @@ -0,0 +1,184 @@ +package proxytoken + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +type handler struct { + store store.Store + permissionsManager permissions.Manager +} + +func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) { + h := &handler{store: s, permissionsManager: permissionsManager} + router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS") +} + +func (h *handler) createToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + var req api.ProxyTokenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if req.Name == "" || len(req.Name) > 255 { + util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w) + return + } + + var expiresIn time.Duration + if req.ExpiresIn != nil && *req.ExpiresIn > 0 { + expiresIn = time.Duration(*req.ExpiresIn) * time.Second + } + + accountID := userAuth.AccountId + generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId) + if err != nil { + util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w) + return + } + + if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil { + util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w) + return + } + + resp := toProxyTokenCreatedResponse(generated) + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId) + if err != nil { + util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w) + return + } + + resp := make([]api.ProxyToken, 0, len(tokens)) + for _, token := range tokens { + resp = append(resp, toProxyTokenResponse(token)) + } + + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokenID := mux.Vars(r)["tokenId"] + if tokenID == "" { + util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w) + return + } + + token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID) + if err != nil { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + return + } + + if token.AccountID == nil || *token.AccountID != userAuth.AccountId { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + return + } + + if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil { + util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken { + resp := api.ProxyToken{ + Id: token.ID, + Name: token.Name, + Revoked: token.Revoked, + } + if !token.CreatedAt.IsZero() { + resp.CreatedAt = token.CreatedAt + } + if token.ExpiresAt != nil { + resp.ExpiresAt = token.ExpiresAt + } + if token.LastUsed != nil { + resp.LastUsed = token.LastUsed + } + return resp +} + +func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated { + base := toProxyTokenResponse(&generated.ProxyAccessToken) + plainToken := string(generated.PlainToken) + return api.ProxyTokenCreated{ + Id: base.Id, + Name: base.Name, + CreatedAt: base.CreatedAt, + ExpiresAt: base.ExpiresAt, + LastUsed: base.LastUsed, + Revoked: base.Revoked, + PlainToken: plainToken, + } +} diff --git a/management/internals/modules/reverseproxy/proxytoken/handler_test.go b/management/internals/modules/reverseproxy/proxytoken/handler_test.go new file mode 100644 index 000000000..a28752909 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler_test.go @@ -0,0 +1,275 @@ +package proxytoken + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func authContext(accountID, userID string) context.Context { + return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{ + AccountId: accountID, + UserId: userID, + }) +} + +func TestCreateToken_AccountScoped(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body := `{"name": "my-token"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp api.ProxyTokenCreated + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + assert.NotEmpty(t, resp.PlainToken) + assert.Equal(t, "my-token", resp.Name) + assert.False(t, resp.Revoked) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.AccountID) + assert.Equal(t, accountID, *savedToken.AccountID) + assert.Equal(t, "user-1", savedToken.CreatedBy) +} + +func TestCreateToken_WithExpiration(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body := `{"name": "expiring-token", "expires_in": 3600}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.ExpiresAt) + assert.True(t, savedToken.ExpiresAt.After(time.Now())) +} + +func TestCreateToken_EmptyName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": ""}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCreateToken_PermissionDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": "test"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestListTokens(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + now := time.Now() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{ + {ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false}, + {ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true}, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.listTokens(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp []api.ProxyToken + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Len(t, resp, 2) + assert.Equal(t, "tok-1", resp[0].Id) + assert.False(t, resp[0].Revoked) + assert.Equal(t, "tok-2", resp[1].Id) + assert.True(t, resp[1].Revoked) +} + +func TestRevokeToken_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + Name: "test-token", + AccountID: &accountID, + }, nil) + mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext(accountID, "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestRevokeToken_WrongAccount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + otherAccount := "acc-other" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: &otherAccount, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestRevokeToken_ManagementWideToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: nil, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} diff --git a/management/internals/modules/reverseproxy/selfhostedproxy/handler.go b/management/internals/modules/reverseproxy/selfhostedproxy/handler.go new file mode 100644 index 000000000..9eb1e885f --- /dev/null +++ b/management/internals/modules/reverseproxy/selfhostedproxy/handler.go @@ -0,0 +1,150 @@ +package selfhostedproxy + +import ( + "net/http" + + "github.com/gorilla/mux" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// ProxyDisconnector can force-disconnect a connected proxy's gRPC stream. +type ProxyDisconnector interface { + ForceDisconnect(proxyID string) +} + +type handler struct { + proxyMgr proxy.Manager + serviceMgr rpservice.Manager + permissionsManager permissions.Manager + disconnector ProxyDisconnector +} + +func RegisterEndpoints(proxyMgr proxy.Manager, serviceMgr rpservice.Manager, permissionsManager permissions.Manager, disconnector ProxyDisconnector, router *mux.Router) { + h := &handler{ + proxyMgr: proxyMgr, + serviceMgr: serviceMgr, + permissionsManager: permissionsManager, + disconnector: disconnector, + } + router.HandleFunc("/reverse-proxies/self-hosted-proxies", h.listProxies).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/self-hosted-proxies/{proxyId}", h.deleteProxy).Methods("DELETE", "OPTIONS") +} + +func (h *handler) listProxies(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + p, err := h.proxyMgr.GetAccountProxy(r.Context(), userAuth.AccountId) + if err != nil { + if isNotFound(err) { + util.WriteJSONObject(r.Context(), w, []api.SelfHostedProxy{}) + return + } + util.WriteErrorResponse("failed to get proxy", http.StatusInternalServerError, w) + return + } + + serviceCount := 0 + services, err := h.serviceMgr.GetAccountServices(r.Context(), userAuth.AccountId) + if err == nil { + for _, svc := range services { + if svc.ProxyCluster == p.ClusterAddress { + serviceCount++ + } + } + } + + resp := []api.SelfHostedProxy{toSelfHostedProxyResponse(p, serviceCount)} + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) deleteProxy(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + proxyID := mux.Vars(r)["proxyId"] + if proxyID == "" { + util.WriteErrorResponse("proxy ID is required", http.StatusBadRequest, w) + return + } + + p, err := h.proxyMgr.GetAccountProxy(r.Context(), userAuth.AccountId) + if err != nil { + util.WriteErrorResponse("proxy not found", http.StatusNotFound, w) + return + } + + if p.ID != proxyID { + util.WriteErrorResponse("proxy not found", http.StatusNotFound, w) + return + } + + if h.disconnector != nil { + h.disconnector.ForceDisconnect(proxyID) + } + + if err := h.proxyMgr.DeleteProxy(r.Context(), proxyID); err != nil { + util.WriteErrorResponse("failed to delete proxy", http.StatusInternalServerError, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func isNotFound(err error) bool { + e, ok := status.FromError(err) + return ok && e.Type() == status.NotFound +} + +func toSelfHostedProxyResponse(p *proxy.Proxy, serviceCount int) api.SelfHostedProxy { + st := api.SelfHostedProxyStatus(p.Status) + resp := api.SelfHostedProxy{ + Id: p.ID, + ClusterAddress: p.ClusterAddress, + Status: st, + LastSeen: p.LastSeen, + ServiceCount: serviceCount, + } + if p.IPAddress != "" { + resp.IpAddress = &p.IPAddress + } + if p.ConnectedAt != nil { + resp.ConnectedAt = p.ConnectedAt + } + return resp +} diff --git a/management/internals/modules/reverseproxy/selfhostedproxy/handler_test.go b/management/internals/modules/reverseproxy/selfhostedproxy/handler_test.go new file mode 100644 index 000000000..00515d24d --- /dev/null +++ b/management/internals/modules/reverseproxy/selfhostedproxy/handler_test.go @@ -0,0 +1,220 @@ +package selfhostedproxy + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" +) + +type mockDisconnector struct { + disconnectedIDs []string +} + +func (m *mockDisconnector) ForceDisconnect(proxyID string) { + m.disconnectedIDs = append(m.disconnectedIDs, proxyID) +} + +func authContext(accountID, userID string) context.Context { + return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{ + AccountId: accountID, + UserId: userID, + }) +} + +func TestListProxies_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + now := time.Now() + connAt := now.Add(-1 * time.Hour) + + proxyMgr := proxy.NewMockManager(ctrl) + proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), accountID).Return(&proxy.Proxy{ + ID: "proxy-1", + ClusterAddress: "byod.example.com", + IPAddress: "10.0.0.1", + AccountID: &accountID, + Status: proxy.StatusConnected, + LastSeen: now, + ConnectedAt: &connAt, + }, nil) + + serviceMgr := rpservice.NewMockManager(ctrl) + serviceMgr.EXPECT().GetAccountServices(gomock.Any(), accountID).Return([]*rpservice.Service{ + {ProxyCluster: "byod.example.com"}, + {ProxyCluster: "byod.example.com"}, + {ProxyCluster: "other.cluster.com"}, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil) + + h := &handler{ + proxyMgr: proxyMgr, + serviceMgr: serviceMgr, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/self-hosted-proxies", nil) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.listProxies(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp []api.SelfHostedProxy + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Len(t, resp, 1) + assert.Equal(t, "proxy-1", resp[0].Id) + assert.Equal(t, "byod.example.com", resp[0].ClusterAddress) + assert.Equal(t, 2, resp[0].ServiceCount) + assert.Equal(t, api.SelfHostedProxyStatus(proxy.StatusConnected), resp[0].Status) +} + +func TestListProxies_NoProxy(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + proxyMgr := proxy.NewMockManager(ctrl) + proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), "acc-123").Return(nil, status.Errorf(status.NotFound, "not found")) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Read).Return(true, nil) + + h := &handler{ + proxyMgr: proxyMgr, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/self-hosted-proxies", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.listProxies(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp []api.SelfHostedProxy + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Empty(t, resp) +} + +func TestListProxies_PermissionDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Read).Return(false, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/self-hosted-proxies", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.listProxies(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestDeleteProxy_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + disconnector := &mockDisconnector{} + + proxyMgr := proxy.NewMockManager(ctrl) + proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), accountID).Return(&proxy.Proxy{ + ID: "proxy-1", + AccountID: &accountID, + Status: proxy.StatusConnected, + }, nil) + proxyMgr.EXPECT().DeleteProxy(gomock.Any(), "proxy-1").Return(nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + proxyMgr: proxyMgr, + permissionsManager: permsMgr, + disconnector: disconnector, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/self-hosted-proxies/proxy-1", nil) + req = req.WithContext(authContext(accountID, "user-1")) + req = mux.SetURLVars(req, map[string]string{"proxyId": "proxy-1"}) + w := httptest.NewRecorder() + + h.deleteProxy(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, disconnector.disconnectedIDs, "proxy-1") +} + +func TestDeleteProxy_WrongProxyID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + + proxyMgr := proxy.NewMockManager(ctrl) + proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), accountID).Return(&proxy.Proxy{ + ID: "proxy-1", + AccountID: &accountID, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + proxyMgr: proxyMgr, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/self-hosted-proxies/proxy-other", nil) + req = req.WithContext(authContext(accountID, "user-1")) + req = mux.SetURLVars(req, map[string]string{"proxyId": "proxy-other"}) + w := httptest.NewRecorder() + + h.deleteProxy(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestDeleteProxy_PermissionDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(false, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/self-hosted-proxies/proxy-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"proxyId": "proxy-1"}) + w := httptest.NewRecorder() + + h.deleteProxy(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index b420f22a8..d66b87eb8 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -25,4 +25,5 @@ type Manager interface { RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error StartExposeReaper(ctx context.Context) + GetServiceByDomain(ctx context.Context, domain string) (*Service, error) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 727b2c7de..e0b34879c 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -122,6 +122,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID) } +// GetServiceByDomain mocks base method. +func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) + ret0, _ := ret[0].(*Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceByDomain indicates an expected call of GetServiceByDomain. +func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain) +} + // GetGlobalServices mocks base method. func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 56a1fc98a..ecee9117d 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -627,6 +627,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]* return services, nil } +func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index ba4e1c805..16f48e5bf 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -425,7 +425,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { require.NoError(t, err) pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) return srv } @@ -706,7 +706,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { require.NoError(t, err) pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) @@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { require.NoError(t, err) pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eb13a15e3..9a1add57f 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.ProxyManager()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store()) s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index e2d0f1abe..2a445bd6d 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net" "net/url" "strings" "sync" @@ -51,6 +52,11 @@ type ClusterInfo struct { ConnectedProxies int } +// ProxyTokenChecker checks whether a proxy access token is still valid. +type ProxyTokenChecker interface { + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) +} + // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -79,6 +85,9 @@ type ProxyServiceServer struct { // Store for one-time authentication tokens tokenStore *OneTimeTokenStore + // Checker for proxy access token validity + tokenChecker ProxyTokenChecker + // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig @@ -90,16 +99,29 @@ const pkceVerifierTTL = 10 * time.Minute // proxyConnection represents a connected proxy type proxyConnection struct { - proxyID string - address string - stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.GetMappingUpdateResponse - ctx context.Context - cancel context.CancelFunc + proxyID string + address string + accountID *string + tokenID string + stream proto.ProxyService_GetMappingUpdateServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc +} + +func enforceAccountScope(ctx context.Context, requestAccountID string) error { + token := GetProxyTokenFromContext(ctx) + if token == nil || token.AccountID == nil { + return nil + } + if requestAccountID == "" || *token.AccountID != requestAccountID { + return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID) + } + return nil } // NewProxyServiceServer creates a new proxy service server. -func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer { ctx := context.Background() s := &ProxyServiceServer{ accessLogManager: accessLogMgr, @@ -109,6 +131,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + tokenChecker: tokenChecker, } go s.cleanupStaleProxies(ctx) return s @@ -155,14 +178,57 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + var accountID *string + token := GetProxyTokenFromContext(ctx) + if token != nil && token.AccountID != nil { + accountID = token.AccountID + + existingProxy, _ := s.proxyManager.GetAccountProxy(ctx, *accountID) + if existingProxy != nil && existingProxy.ID != proxyID { + if existingProxy.Status == proxy.StatusConnected { + return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account") + } + if err := s.proxyManager.DeleteProxy(ctx, existingProxy.ID); err != nil { + log.WithContext(ctx).Warnf("failed to cleanup disconnected proxy %s: %v", existingProxy.ID, err) + } + } + + available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID) + if err != nil { + return status.Errorf(codes.Internal, "check cluster address: %v", err) + } + if !available { + return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress) + } + } + + var tokenID string + if token != nil { + tokenID = token.ID + } + connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ - proxyID: proxyID, - address: proxyAddress, - stream: stream, - sendChan: make(chan *proto.GetMappingUpdateResponse, 100), - ctx: connCtx, - cancel: cancel, + proxyID: proxyID, + address: proxyAddress, + accountID: accountID, + tokenID: tokenID, + stream: stream, + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), + ctx: connCtx, + cancel: cancel, + } + + // Register proxy in database + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID); err != nil { + if accountID != nil { + cancel() + if strings.Contains(err.Error(), "UNIQUE constraint") || strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "idx_proxy_account_id_unique") { + return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account") + } + return status.Errorf(codes.Internal, "failed to register BYOD proxy: %v", err) + } + log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) } s.connectedProxies.Store(proxyID, conn) @@ -170,15 +236,11 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) } - // Register proxy in database - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) - } - log.WithFields(log.Fields{ "proxy_id": proxyID, "address": proxyAddress, "cluster_addr": proxyAddress, + "account_id": accountID, "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { @@ -203,7 +265,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest go s.sender(conn, errChan) // Start heartbeat goroutine - go s.heartbeat(connCtx, proxyID) + go s.heartbeat(connCtx, conn) select { case err := <-errChan: @@ -213,16 +275,30 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { +// heartbeat updates the proxy's last_seen timestamp every minute and +// validates that the proxy's access token is still valid. +func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil { - log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) + if err := s.proxyManager.Heartbeat(ctx, conn.proxyID); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", conn.proxyID, err) + } + + if conn.tokenID != "" && s.tokenChecker != nil { + valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID) + if err != nil { + log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err) + continue + } + if !valid { + log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID) + conn.cancel() + return + } } case <-ctx.Done(): return @@ -232,8 +308,15 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { // sendSnapshot sends the initial snapshot of services to the connecting proxy. // Only services matching the proxy's cluster address are sent. +// For BYOD proxies (account-scoped), only account services are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - services, err := s.serviceManager.GetGlobalServices(ctx) + var services []*rpservice.Service + var err error + if conn.accountID != nil { + services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID) + } else { + services, err = s.serviceManager.GetGlobalServices(ctx) + } if err != nil { return fmt.Errorf("get services from store: %w", err) } @@ -295,8 +378,14 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } -// isProxyAddressValid validates a proxy address +// isProxyAddressValid validates a proxy address (domain name or IP address) func isProxyAddressValid(addr string) bool { + if addr == "" { + return false + } + if net.ParseIP(addr) != nil { + return true + } _, err := domain.ValidateDomains([]string{addr}) return err == nil } @@ -320,6 +409,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() + if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil { + return nil, err + } + fields := log.Fields{ "service_id": accessLog.GetServiceId(), "account_id": accessLog.GetAccountId(), @@ -357,10 +450,18 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. +// BYOD proxies only receive updates for their own account's services. func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") + var updateAccountID string + if len(update.Mapping) > 0 { + updateAccountID = update.Mapping[0].AccountId + } s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) + if conn.accountID != nil && updateAccountID != "" && *conn.accountID != updateAccountID { + return true + } msg := s.perProxyMessage(update, conn.proxyID) if msg == nil { return true @@ -375,6 +476,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes }) } +// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect. +func (s *ProxyServiceServer) ForceDisconnect(proxyID string) { + if connVal, ok := s.connectedProxies.Load(proxyID); ok { + conn := connVal.(*proxyConnection) + conn.cancel() + s.connectedProxies.Delete(proxyID) + log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy") + } +} + // GetConnectedProxies returns a list of connected proxy IDs func (s *ProxyServiceServer) GetConnectedProxies() []string { var proxies []string @@ -440,6 +551,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd for _, proxyID := range proxyIDs { if connVal, ok := s.connectedProxies.Load(proxyID); ok { conn := connVal.(*proxyConnection) + if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId { + continue + } msg := s.perProxyMessage(updateResponse, proxyID) if msg == nil { continue @@ -499,6 +613,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) @@ -587,6 +705,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic // SendStatusUpdate handles status updates from proxy clients func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + accountID := req.GetAccountId() serviceID := req.GetServiceId() protoStatus := req.GetStatus() @@ -653,6 +775,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { // CreateProxyPeer handles proxy peer creation with one-time token authentication func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + serviceID := req.GetServiceId() accountID := req.GetAccountId() token := req.GetToken() @@ -707,6 +833,10 @@ func strPtr(s string) *string { } func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + redirectURL, err := url.Parse(req.GetRedirectUrl()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) @@ -829,21 +959,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL // GenerateSessionToken creates a signed session JWT for the given domain and user. func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { - // Find the service by domain to get its signing key - services, err := s.serviceManager.GetGlobalServices(ctx) + service, err := s.getServiceByDomain(ctx, domain) if err != nil { - return "", fmt.Errorf("get services: %w", err) - } - - var service *rpservice.Service - for _, svc := range services { - if svc.Domain == domain { - service = svc - break - } - } - if service == nil { - return "", fmt.Errorf("service not found for domain: %s", domain) + return "", fmt.Errorf("service not found for domain %s: %w", domain, err) } if service.SessionPrivateKey == "" { @@ -941,6 +1059,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } + if err := enforceAccountScope(ctx, service.AccountID); err != nil { + return nil, err + } + pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey) if err != nil { log.WithFields(log.Fields{ @@ -1024,18 +1146,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val } func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return nil, fmt.Errorf("get services: %w", err) - } - - for _, service := range services { - if service.Domain == domain { - return service, nil - } - } - - return nil, fmt.Errorf("service not found for domain: %s", domain) + return s.serviceManager.GetServiceByDomain(ctx, domain) } func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { diff --git a/management/internals/shared/grpc/proxy_address_test.go b/management/internals/shared/grpc/proxy_address_test.go new file mode 100644 index 000000000..8e3181f78 --- /dev/null +++ b/management/internals/shared/grpc/proxy_address_test.go @@ -0,0 +1,29 @@ +package grpc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsProxyAddressValid(t *testing.T) { + tests := []struct { + name string + addr string + valid bool + }{ + {name: "valid domain", addr: "eu.proxy.netbird.io", valid: true}, + {name: "valid subdomain", addr: "byod.proxy.example.com", valid: true}, + {name: "valid IPv4", addr: "10.0.0.1", valid: true}, + {name: "valid IPv4 public", addr: "203.0.113.10", valid: true}, + {name: "valid IPv6", addr: "::1", valid: true}, + {name: "valid IPv6 full", addr: "2001:db8::1", valid: true}, + {name: "empty string", addr: "", valid: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr)) + }) + } +} diff --git a/management/internals/shared/grpc/proxy_auth.go b/management/internals/shared/grpc/proxy_auth.go index dd593dfa0..9888e8eee 100644 --- a/management/internals/shared/grpc/proxy_auth.go +++ b/management/internals/shared/grpc/proxy_auth.go @@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types return nil, status.Errorf(codes.Unauthenticated, "invalid token") } - // TODO: Enforce AccountID scope for "bring your own proxy" feature. - // Currently tokens are management-wide; AccountID field is reserved for future use. - if !token.IsValid() { return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked") } diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 22fe4506b..b2cb64581 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -90,6 +90,17 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} +func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) { + for _, services := range m.proxiesByAccount { + for _, svc := range services { + if svc.Domain == domain { + return svc, nil + } + } + } + return nil, errors.New("service not found for domain: " + domain) +} + type mockUsersManager struct { users map[string]*types.User err error diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 647e8443b..d2d671b12 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/store" @@ -44,7 +45,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil) proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) @@ -320,9 +321,13 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string) error { return nil } @@ -338,10 +343,30 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co return nil, nil } +func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } +func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*nbproxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error { + return nil +} + type testValidateSessionUsersManager struct { store store.Store } diff --git a/management/server/account_test.go b/management/server/account_test.go index fdec43617..fb42a47bf 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3133,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) if err != nil { return nil, nil, err diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ddeda6d7f..2674f77fd 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -19,6 +19,9 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/selfhostedproxy" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -73,7 +76,7 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, proxyMgr rpproxy.Manager) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -177,6 +180,11 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) } + proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router) + if proxyMgr != nil && serviceManager != nil { + selfhostedproxy.RegisterEndpoints(proxyMgr, serviceManager, permissionsManager, proxyGRPCServer, router) + } + // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 3bed54e80..f8b1e0f74 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -214,6 +214,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { nil, usersManager, nil, + nil, ) proxyService.SetServiceManager(&testServiceManager{store: testStore}) @@ -433,6 +434,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri func (m *testServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { t.Helper() diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 462013963..d205f89e5 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -107,7 +107,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { @@ -133,7 +133,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 5997c10e2..3523a8aba 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4464,6 +4464,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e return nil } +func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var tokens []*types.ProxyAccessToken + result := tx.Where("account_id = ?", accountID).Find(&tokens) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error) + } + + return tokens, nil +} + +func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID) + if err != nil { + return false, err + } + return token.IsValid(), nil +} + +func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) { + tx := s.db.WithContext(ctx) + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var token types.ProxyAccessToken + result := tx.Take(&token, idQueryCondition, tokenID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy access token not found") + } + return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error) + } + + return &token, nil +} + // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}). @@ -5370,7 +5411,25 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { result := s.db.WithContext(ctx).Save(p) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error) - return status.Errorf(status.Internal, "failed to save proxy") + return status.Errorf(status.Internal, "failed to save proxy: %v", result.Error) + } + return nil +} + +// DisconnectProxy updates only the status, disconnected_at, and last_seen fields +func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error { + now := time.Now() + result := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Where("id = ?", proxyID). + Updates(map[string]interface{}{ + "status": proxy.StatusDisconnected, + "disconnected_at": now, + "last_seen": now, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error) + return status.Errorf(status.Internal, "failed to disconnect proxy") } return nil } @@ -5379,7 +5438,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error { result := s.db.WithContext(ctx). Model(&proxy.Proxy{}). - Where("id = ? AND status = ?", proxyID, "connected"). + Where("id = ? AND status = ?", proxyID, proxy.StatusConnected). Update("last_seen", time.Now()) if result.Error != nil { @@ -5395,7 +5454,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string result := s.db.WithContext(ctx). Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-2*time.Minute)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5407,6 +5466,66 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string return addresses, nil } +func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + var addresses []string + + result := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)). + Distinct("cluster_address"). + Pluck("cluster_address", &addresses) + + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account") + } + + return addresses, nil +} + +func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + var p proxy.Proxy + result := s.db.WithContext(ctx).Where("account_id = ?", accountID).Take(&p) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy not found for account") + } + return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error) + } + return &p, nil +} + +func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + var count int64 + result := s.db.WithContext(ctx).Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error) + } + return count, nil +} + +func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + var count int64 + result := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID). + Count(&count) + if result.Error != nil { + return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error) + } + return count > 0, nil +} + +func (s *SqlStore) DeleteProxy(ctx context.Context, proxyID string) error { + result := s.db.WithContext(ctx).Where(idQueryCondition, proxyID).Delete(&proxy.Proxy{}) + if result.Error != nil { + return status.Errorf(status.Internal, "delete proxy: %v", result.Error) + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "proxy not found") + } + return nil +} + // CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { cutoffTime := time.Now().Add(-inactivityDuration) diff --git a/management/server/store/store.go b/management/server/store/store.go index 1fa99fd05..14a64be2c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -114,6 +114,9 @@ type Store interface { GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error RevokeProxyAccessToken(ctx context.Context, tokenID string) error MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error @@ -282,9 +285,15 @@ type Store interface { DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error + DisconnectProxy(ctx context.Context, proxyID string) error UpdateProxyHeartbeat(ctx context.Context, proxyID string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteProxy(ctx context.Context, proxyID string) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 130df4485..97a5503f6 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -165,6 +165,21 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) } +// CountProxiesByAccountID mocks base method. +func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID. +func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID) +} + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -1287,6 +1302,21 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) } +// GetActiveProxyClusterAddressesForAccount mocks base method. +func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount. +func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID) +} + // GetAllAccounts mocks base method. func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { m.ctrl.T.Helper() @@ -1331,6 +1361,51 @@ func (mr *MockStoreMockRecorder) GetAllProxyAccessTokens(ctx, lockStrength inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxyAccessTokens", reflect.TypeOf((*MockStore)(nil).GetAllProxyAccessTokens), ctx, lockStrength) } +// GetProxyAccessTokensByAccountID mocks base method. +func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID. +func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID) +} + +// GetProxyAccessTokenByID mocks base method. +func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID) + ret0, _ := ret[0].(*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID. +func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID) +} + +// IsProxyAccessTokenValid mocks base method. +func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid. +func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID) +} + // GetAnyAccountID mocks base method. func (m *MockStore) GetAnyAccountID(ctx context.Context) (string, error) { m.ctrl.T.Helper() @@ -1901,6 +1976,50 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken) } +// GetProxyByAccountID mocks base method. +func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID) + ret0, _ := ret[0].(*proxy.Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyByAccountID indicates an expected call of GetProxyByAccountID. +func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID) +} + +// IsClusterAddressConflicting mocks base method. +func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting. +func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID) +} + +// DeleteProxy mocks base method. +func (m *MockStore) DeleteProxy(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteProxy indicates an expected call of DeleteProxy. +func (mr *MockStoreMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockStore)(nil).DeleteProxy), ctx, proxyID) +} + // GetResourceGroups mocks base method. func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { m.ctrl.T.Helper() @@ -2698,6 +2817,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) } +// DisconnectProxy mocks base method. +func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectProxy indicates an expected call of DisconnectProxy. +func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID) +} + // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() diff --git a/proxy/management_byod_integration_test.go b/proxy/management_byod_integration_test.go new file mode 100644 index 000000000..d40178527 --- /dev/null +++ b/proxy/management_byod_integration_test.go @@ -0,0 +1,408 @@ +package proxy + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + grpcstatus "google.golang.org/grpc/status" + + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type byodTestSetup struct { + store store.Store + proxyService *nbgrpc.ProxyServiceServer + grpcServer *grpc.Server + grpcAddr string + cleanup func() + + accountA string + accountB string + accountAToken types.PlainProxyToken + accountBToken types.PlainProxyToken + accountACluster string + accountBCluster string +} + +func setupBYODIntegrationTest(t *testing.T) *byodTestSetup { + t.Helper() + ctx := context.Background() + + testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + + accountAID := "byod-account-a" + accountBID := "byod-account-b" + + for _, acc := range []*types.Account{ + {Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + {Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + } { + require.NoError(t, testStore.SaveAccount(ctx, acc)) + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pubKey := base64.StdEncoding.EncodeToString(pub) + privKey := base64.StdEncoding.EncodeToString(priv) + + clusterA := "byod-a.proxy.test" + clusterB := "byod-b.proxy.test" + + services := []*service.Service{ + { + ID: "svc-a1", AccountID: accountAID, Name: "App A1", + Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-a2", AccountID: accountAID, Name: "App A2", + Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-b1", AccountID: accountBID, Name: "App B1", + Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}}, + }, + } + for _, svc := range services { + require.NoError(t, testStore.CreateService(ctx, svc)) + } + + tokenA, err := types.CreateNewProxyAccessToken("byod-token-a", 0, &accountAID, "admin-a") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken)) + + tokenB, err := types.CreateNewProxyAccessToken("byod-token-b", 0, &accountBID, "admin-b") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken)) + + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + meter := noop.NewMeterProvider().Meter("test") + realProxyManager, err := proxymanager.NewManager(testStore, meter) + require.NoError(t, err) + + oidcConfig := nbgrpc.ProxyOIDCConfig{ + Issuer: "https://fake-issuer.example.com", + ClientID: "test-client", + HMACKey: []byte("test-hmac-key"), + } + + usersManager := users.NewManager(testStore) + + proxyService := nbgrpc.NewProxyServiceServer( + &testAccessLogManager{}, + tokenStore, + pkceStore, + oidcConfig, + nil, + usersManager, + realProxyManager, + nil, + ) + + svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} + proxyService.SetServiceManager(svcMgr) + + proxyController := &testProxyController{} + proxyService.SetProxyController(proxyController) + + _, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor)) + proto.RegisterProxyServiceServer(grpcServer, proxyService) + + go func() { + if err := grpcServer.Serve(lis); err != nil { + t.Logf("gRPC server error: %v", err) + } + }() + + return &byodTestSetup{ + store: testStore, + proxyService: proxyService, + grpcServer: grpcServer, + grpcAddr: lis.Addr().String(), + cleanup: func() { + grpcServer.GracefulStop() + authClose() + storeCleanup() + }, + accountA: accountAID, + accountB: accountBID, + accountAToken: tokenA.PlainToken, + accountBToken: tokenB.PlainToken, + accountACluster: clusterA, + accountBCluster: clusterB, + } +} + +func byodContext(ctx context.Context, token types.PlainProxyToken) context.Context { + md := metadata.Pairs("authorization", "Bearer "+string(token)) + return metadata.NewOutgoingContext(ctx, md) +} + +func receiveBYODMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { + t.Helper() + var mappings []*proto.ProxyMapping + for { + msg, err := stream.Recv() + require.NoError(t, err) + mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } + } + return mappings +} + +func TestIntegration_BYODProxy_ReceivesOnlyAccountServices(t *testing.T) { + setup := setupBYODIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byodContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byod-proxy-a", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings := receiveBYODMappings(t, stream) + + assert.Len(t, mappings, 2, "BYOD proxy should receive only account A's 2 services") + for _, m := range mappings { + assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A") + t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId()) + } + + ids := map[string]bool{} + for _, m := range mappings { + ids[m.GetId()] = true + } + assert.True(t, ids["svc-a1"], "should contain svc-a1") + assert.True(t, ids["svc-a2"], "should contain svc-a2") + assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1") +} + +func TestIntegration_BYODProxy_AccountBReceivesOnlyItsServices(t *testing.T) { + setup := setupBYODIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byodContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byod-proxy-b", + Version: "test-v1", + Address: setup.accountBCluster, + }) + require.NoError(t, err) + + mappings := receiveBYODMappings(t, stream) + + assert.Len(t, mappings, 1, "BYOD proxy B should receive only 1 service") + assert.Equal(t, "svc-b1", mappings[0].GetId()) + assert.Equal(t, setup.accountB, mappings[0].GetAccountId()) +} + +func TestIntegration_BYODProxy_LimitOnePerAccount(t *testing.T) { + setup := setupBYODIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byodContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byod-proxy-a-first", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _ = receiveBYODMappings(t, stream1) + + ctx2, cancel2 := context.WithTimeout(byodContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byod-proxy-a-second", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _, err = stream2.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.ResourceExhausted, st.Code(), "second BYOD proxy should be rejected with ResourceExhausted") + t.Logf("expected rejection: %s", st.Message()) +} + +func TestIntegration_BYODProxy_ClusterAddressConflict(t *testing.T) { + setup := setupBYODIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byodContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byod-proxy-a-cluster", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _ = receiveBYODMappings(t, stream1) + + ctx2, cancel2 := context.WithTimeout(byodContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byod-proxy-b-conflict", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _, err = stream2.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists") + t.Logf("expected rejection: %s", st.Message()) +} + +func TestIntegration_BYODProxy_SameProxyReconnects(t *testing.T) { + setup := setupBYODIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + proxyID := "byod-proxy-reconnect" + + ctx1, cancel1 := context.WithTimeout(byodContext(context.Background(), setup.accountAToken), 5*time.Second) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + firstMappings := receiveBYODMappings(t, stream1) + cancel1() + + time.Sleep(200 * time.Millisecond) + + ctx2, cancel2 := context.WithTimeout(byodContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + secondMappings := receiveBYODMappings(t, stream2) + + assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings") + + firstIDs := map[string]bool{} + for _, m := range firstMappings { + firstIDs[m.GetId()] = true + } + for _, m := range secondMappings { + assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId()) + } +} + +func TestIntegration_BYODProxy_UnauthenticatedRejected(t *testing.T) { + setup := setupBYODIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "no-auth-proxy", + Version: "test-v1", + Address: "some.cluster.io", + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unauthenticated, st.Code()) +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 6a0ecce30..3bd9d47c4 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -139,6 +139,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { nil, usersManager, proxyManager, + nil, ) // Use store-backed service manager @@ -200,7 +201,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string) error { return nil } @@ -216,10 +217,30 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin return nil, nil } +func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } +func (m *testProxyManager) GetAccountProxy(_ context.Context, _ string) (*nbproxy.Proxy, error) { + return nil, nil +} + +func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testProxyManager) DeleteProxy(_ context.Context, _ string) error { + return nil +} + // testProxyController is a mock implementation of rpservice.ProxyController for testing. type testProxyController struct{} @@ -319,6 +340,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} +func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func strPtr(s string) *string { return &s } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index c67231342..150f3af28 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3151,6 +3151,86 @@ components: description: Whether link auth is enabled required: - enabled + ProxyTokenRequest: + type: object + properties: + name: + type: string + description: Human-readable token name + example: "my-proxy-token" + expires_in: + type: integer + description: Token expiration in seconds (0 = never expires) + example: 0 + required: + - name + ProxyToken: + type: object + properties: + id: + type: string + name: + type: string + expires_at: + type: string + format: date-time + created_at: + type: string + format: date-time + last_used: + type: string + format: date-time + revoked: + type: boolean + required: + - id + - name + - created_at + - revoked + ProxyTokenCreated: + type: object + description: Returned on creation — plain_token is shown only once + allOf: + - $ref: '#/components/schemas/ProxyToken' + - type: object + properties: + plain_token: + type: string + description: The plain text token (shown only once) + example: "nbx_abc123..." + required: + - plain_token + SelfHostedProxy: + type: object + properties: + id: + type: string + description: Proxy instance ID + cluster_address: + type: string + description: Cluster domain or IP address + example: "proxy.example.com" + ip_address: + type: string + description: Proxy IP address + status: + type: string + enum: [ connected, disconnected ] + last_seen: + type: string + format: date-time + connected_at: + type: string + format: date-time + service_count: + type: integer + description: Number of services routed through this proxy's cluster + required: + - id + - cluster_address + - status + - last_seen + - service_count ProxyCluster: type: object description: A proxy cluster represents a group of proxy nodes serving the same address @@ -9617,6 +9697,131 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /api/reverse-proxies/proxy-tokens: + get: + summary: List Proxy Tokens + description: Returns all proxy access tokens for the account + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy tokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyToken' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Proxy Token + description: Generate an account-scoped proxy access token for self-hosted proxy registration + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenRequest' + responses: + '200': + description: Proxy token created (plain token shown once) + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenCreated' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens/{tokenId}: + delete: + summary: Revoke a Proxy Token + description: Revoke an account-scoped proxy access token + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: tokenId + required: true + schema: + type: string + description: The unique identifier of the proxy token + responses: + '200': + description: Token revoked + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/self-hosted-proxies: + get: + summary: List Self-Hosted Proxies + description: Returns self-hosted proxies registered for the account + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of self-hosted proxies + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/SelfHostedProxy' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/self-hosted-proxies/{proxyId}: + delete: + summary: Delete a Self-Hosted Proxy + description: Remove a self-hosted proxy from the account + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: proxyId + required: true + schema: + type: string + description: The unique identifier of the proxy + responses: + '200': + description: Proxy deleted + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index f218679c0..45097f641 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -859,6 +859,24 @@ func (e ReverseProxyDomainType) Valid() bool { } } +// Defines values for SelfHostedProxyStatus. +const ( + SelfHostedProxyStatusConnected SelfHostedProxyStatus = "connected" + SelfHostedProxyStatusDisconnected SelfHostedProxyStatus = "disconnected" +) + +// Valid indicates whether the value is a known member of the SelfHostedProxyStatus enum. +func (e SelfHostedProxyStatus) Valid() bool { + switch e { + case SelfHostedProxyStatusConnected: + return true + case SelfHostedProxyStatusDisconnected: + return true + default: + return false + } +} + // Defines values for SentinelOneMatchAttributesNetworkStatus. const ( SentinelOneMatchAttributesNetworkStatusConnected SentinelOneMatchAttributesNetworkStatus = "connected" @@ -3292,6 +3310,38 @@ type ProxyCluster struct { ConnectedProxies int `json:"connected_proxies"` } +// ProxyToken defines model for ProxyToken. +type ProxyToken struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenCreated defines model for ProxyTokenCreated. +type ProxyTokenCreated struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + + // PlainToken The plain text token (shown only once) + PlainToken string `json:"plain_token"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenRequest defines model for ProxyTokenRequest. +type ProxyTokenRequest struct { + // ExpiresIn Token expiration in seconds (0 = never expires) + ExpiresIn *int `json:"expires_in,omitempty"` + + // Name Human-readable token name + Name string `json:"name"` +} + // Resource defines model for Resource. type Resource struct { // Id ID of the resource @@ -3461,6 +3511,27 @@ type ScimTokenResponse struct { AuthToken string `json:"auth_token"` } +// SelfHostedProxy defines model for SelfHostedProxy. +type SelfHostedProxy struct { + // ClusterAddress Cluster domain or IP address + ClusterAddress string `json:"cluster_address"` + ConnectedAt *time.Time `json:"connected_at,omitempty"` + + // Id Proxy instance ID + Id string `json:"id"` + + // IpAddress Proxy IP address + IpAddress *string `json:"ip_address,omitempty"` + LastSeen time.Time `json:"last_seen"` + + // ServiceCount Number of services routed through this proxy's cluster + ServiceCount int `json:"service_count"` + Status SelfHostedProxyStatus `json:"status"` +} + +// SelfHostedProxyStatus defines model for SelfHostedProxy.Status. +type SelfHostedProxyStatus string + // SentinelOneMatchAttributes Attribute conditions to match when approving agents type SentinelOneMatchAttributes struct { // ActiveThreats The maximum allowed number of active threats on the agent @@ -4481,6 +4552,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate // PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType. type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest +// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType. +type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest + // PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType. type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest