mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-14 12:49:57 +00:00
[proxy] feature: bring your own proxy (#5627)
This commit is contained in:
@@ -31,6 +31,7 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
var ret []*domain.Domain
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
// Verify the target cluster is in the available clusters for this account
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
|
||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||
}
|
||||
|
||||
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
|
||||
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
bestCluster := ""
|
||||
bestLen := -1
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveClusterAddressesFunc != nil {
|
||||
return m.getActiveClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveClusterAddressesForAccountFunc != nil {
|
||||
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
@@ -11,15 +11,19 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
|
||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||
Heartbeat(ctx context.Context, p *Proxy) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
|
||||
@@ -16,11 +16,16 @@ type store interface {
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
@@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
AccountID: accountID,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
Status: proxy.StatusConnected,
|
||||
Capabilities: caps,
|
||||
}
|
||||
|
||||
@@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
@@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
@@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
@@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
clusters, err := m.store.GetActiveProxyClusters(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
@@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !conflicting, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,337 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.saveProxyFunc != nil {
|
||||
return m.saveProxyFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
if m.disconnectProxyFunc != nil {
|
||||
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.updateProxyHeartbeatFunc != nil {
|
||||
return m.updateProxyHeartbeatFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
if m.cleanupStaleProxiesFunc != nil {
|
||||
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
if m.getProxyByAccountIDFunc != nil {
|
||||
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
if m.countProxiesByAccountIDFunc != nil {
|
||||
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
if m.isClusterAddressConflictingFunc != nil {
|
||||
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if m.deleteAccountClusterFunc != nil {
|
||||
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
m, err := NewManager(s, meter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestConnect_WithAccountID(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||
assert.Equal(t, "session-1", savedProxy.SessionID)
|
||||
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||
}
|
||||
|
||||
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Nil(t, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
}
|
||||
|
||||
func TestConnect_StoreError(t *testing.T) {
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conflicting bool
|
||||
storeErr error
|
||||
wantResult bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available - no conflict",
|
||||
conflicting: false,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "not available - conflict exists",
|
||||
conflicting: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||
return tt.conflicting, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountAccountProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
storeErr error
|
||||
wantCount int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no proxies",
|
||||
count: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "one proxy",
|
||||
count: 1,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||
return tt.count, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCount, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxy(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
expected := &proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byop.example.com",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||
assert.Equal(t, accountID, accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, p)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteAccountCluster(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var deletedCluster, deletedAccount string
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
|
||||
deletedCluster = clusterAddress
|
||||
deletedAccount = accountID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "cluster.example.com", deletedCluster)
|
||||
assert.Equal(t, "acc-123", deletedAccount)
|
||||
})
|
||||
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||
expected := []string{"byop.example.com"}
|
||||
s := &mockStore{
|
||||
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
|
||||
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
|
||||
ret0, _ := ret[0].([]Cluster)
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
@@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// GetAccountProxy mocks base method.
|
||||
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountProxy indicates an expected call of GetAccountProxy.
|
||||
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
|
||||
}
|
||||
|
||||
// CountAccountProxies mocks base method.
|
||||
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAccountProxies indicates an expected call of CountAccountProxies.
|
||||
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable mocks base method.
|
||||
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
|
||||
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package proxy
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusConnected = "connected"
|
||||
StatusDisconnected = "disconnected"
|
||||
)
|
||||
|
||||
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
||||
// Nil fields mean the proxy never reported this capability.
|
||||
@@ -21,6 +28,7 @@ type Proxy struct {
|
||||
SessionID string `gorm:"type:varchar(36)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
@@ -36,6 +44,8 @@ func (Proxy) TableName() string {
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
}
|
||||
|
||||
195
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
195
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{store: s, permissionsManager: permissionsManager}
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.ProxyTokenRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" || len(req.Name) > 255 {
|
||||
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
var expiresIn time.Duration
|
||||
if req.ExpiresIn != nil {
|
||||
if *req.ExpiresIn < 0 {
|
||||
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
if *req.ExpiresIn > 0 {
|
||||
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
accountID := userAuth.AccountId
|
||||
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toProxyTokenCreatedResponse(generated)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.ProxyToken, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
resp = append(resp, toProxyTokenResponse(token))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := mux.Vars(r)["tokenId"]
|
||||
if tokenID == "" {
|
||||
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
} else {
|
||||
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||
resp := api.ProxyToken{
|
||||
Id: token.ID,
|
||||
Name: token.Name,
|
||||
Revoked: token.Revoked,
|
||||
}
|
||||
if !token.CreatedAt.IsZero() {
|
||||
resp.CreatedAt = token.CreatedAt
|
||||
}
|
||||
if token.ExpiresAt != nil {
|
||||
resp.ExpiresAt = token.ExpiresAt
|
||||
}
|
||||
if token.LastUsed != nil {
|
||||
resp.LastUsed = token.LastUsed
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
|
||||
base := toProxyTokenResponse(&generated.ProxyAccessToken)
|
||||
plainToken := string(generated.PlainToken)
|
||||
return api.ProxyTokenCreated{
|
||||
Id: base.Id,
|
||||
Name: base.Name,
|
||||
CreatedAt: base.CreatedAt,
|
||||
ExpiresAt: base.ExpiresAt,
|
||||
LastUsed: base.LastUsed,
|
||||
Revoked: base.Revoked,
|
||||
PlainToken: plainToken,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func authContext(accountID, userID string) context.Context {
|
||||
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateToken_AccountScoped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "my-token"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp api.ProxyTokenCreated
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
assert.NotEmpty(t, resp.PlainToken)
|
||||
assert.Equal(t, "my-token", resp.Name)
|
||||
assert.False(t, resp.Revoked)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.AccountID)
|
||||
assert.Equal(t, accountID, *savedToken.AccountID)
|
||||
assert.Equal(t, "user-1", savedToken.CreatedBy)
|
||||
}
|
||||
|
||||
func TestCreateToken_WithExpiration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "expiring-token", "expires_in": 3600}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.ExpiresAt)
|
||||
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestCreateToken_EmptyName(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": ""}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "test"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestListTokens(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
now := time.Now()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
|
||||
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
|
||||
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listTokens(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.ProxyToken
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Len(t, resp, 2)
|
||||
assert.Equal(t, "tok-1", resp[0].Id)
|
||||
assert.False(t, resp[0].Revoked)
|
||||
assert.Equal(t, "tok-2", resp[1].Id)
|
||||
assert.True(t, resp[1].Revoked)
|
||||
}
|
||||
|
||||
func TestRevokeToken_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
Name: "test-token",
|
||||
AccountID: &accountID,
|
||||
}, nil)
|
||||
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
otherAccount := "acc-other"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: &otherAccount,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: nil,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
@@ -28,4 +29,5 @@ type Manager interface {
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
|
||||
}
|
||||
|
||||
@@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||
|
||||
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||
@@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
})
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, apiClusters)
|
||||
}
|
||||
|
||||
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
clusterAddress := mux.Vars(r)["clusterAddress"]
|
||||
if clusterAddress == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
@@ -122,7 +122,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx)
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
// owned by the account.
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
@@ -986,6 +1000,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
|
||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user