mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[proxy] feature: bring your own proxy
This commit is contained in:
@@ -31,6 +31,7 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -68,8 +69,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
var ret []*domain.Domain
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
// For BYOD accounts, only their own cluster is returned; otherwise shared clusters.
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
@@ -112,8 +113,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
// Verify the target cluster is in the available clusters for this account
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -259,7 +260,7 @@ func (m Manager) GetClusterDomains() []string {
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -284,6 +285,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
|
||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||
}
|
||||
|
||||
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
|
||||
byodAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOD cluster addresses: %w", err)
|
||||
}
|
||||
if len(byodAddresses) > 0 {
|
||||
return byodAddresses, nil
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
|
||||
for _, customDomain := range customDomains {
|
||||
if strings.HasSuffix(domain, "."+customDomain.Domain) {
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveClusterAddressesFunc != nil {
|
||||
return m.getActiveClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveClusterAddressesForAccountFunc != nil {
|
||||
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYODProxy(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byod.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOD addresses exist")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byod.example.com"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOD_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYODError_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYODEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestExtractClusterFromFreeDomain(t *testing.T) {
|
||||
clusters := []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
wantCluster string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "matches EU cluster",
|
||||
domain: "myapp.abc123.eu.proxy.netbird.io",
|
||||
wantCluster: "eu.proxy.netbird.io",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "matches US cluster",
|
||||
domain: "myapp.xyz789.us.proxy.netbird.io",
|
||||
wantCluster: "us.proxy.netbird.io",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "no match - custom domain",
|
||||
domain: "app.example.com",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "no match - partial cluster name",
|
||||
domain: "proxy.netbird.io",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "exact cluster name - no prefix",
|
||||
domain: "eu.proxy.netbird.io",
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cluster, ok := ExtractClusterFromFreeDomain(tt.domain, clusters)
|
||||
assert.Equal(t, tt.wantOK, ok)
|
||||
if tt.wantOK {
|
||||
assert.Equal(t, tt.wantCluster, cluster)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,11 +11,16 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteProxy(ctx context.Context, proxyID string) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
|
||||
@@ -13,9 +13,15 @@ import (
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
DisconnectProxy(ctx context.Context, proxyID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteProxy(ctx context.Context, proxyID string) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
@@ -38,15 +44,16 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
AccountID: accountID,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
@@ -64,16 +71,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
func (m *Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
@@ -86,7 +85,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
func (m *Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
return err
|
||||
@@ -96,7 +95,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
@@ -106,10 +105,44 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !conflicting, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.DeleteProxy(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,321 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, proxyID string) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteProxyFunc func(ctx context.Context, proxyID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.saveProxyFunc != nil {
|
||||
return m.saveProxyFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||
if m.disconnectProxyFunc != nil {
|
||||
return m.disconnectProxyFunc(ctx, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
if m.updateProxyHeartbeatFunc != nil {
|
||||
return m.updateProxyHeartbeatFunc(ctx, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
if m.cleanupStaleProxiesFunc != nil {
|
||||
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
if m.getProxyByAccountIDFunc != nil {
|
||||
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
if m.countProxiesByAccountIDFunc != nil {
|
||||
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
if m.isClusterAddressConflictingFunc != nil {
|
||||
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockStore) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
if m.deleteProxyFunc != nil {
|
||||
return m.deleteProxyFunc(ctx, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
m, err := NewManager(s, meter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestConnect_WithAccountID(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||
}
|
||||
|
||||
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Nil(t, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
}
|
||||
|
||||
func TestConnect_StoreError(t *testing.T) {
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conflicting bool
|
||||
storeErr error
|
||||
wantResult bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available - no conflict",
|
||||
conflicting: false,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "not available - conflict exists",
|
||||
conflicting: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||
return tt.conflicting, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountAccountProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
storeErr error
|
||||
wantCount int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no proxies",
|
||||
count: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "one proxy",
|
||||
count: 1,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||
return tt.count, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCount, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxy(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
expected := &proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byod.example.com",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||
assert.Equal(t, accountID, accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, p)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteProxy(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var deletedID string
|
||||
s := &mockStore{
|
||||
deleteProxyFunc: func(_ context.Context, proxyID string) error {
|
||||
deletedID = proxyID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteProxy(context.Background(), "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "proxy-1", deletedID)
|
||||
})
|
||||
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
deleteProxyFunc: func(_ context.Context, _ string) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteProxy(context.Background(), "proxy-1")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||
expected := []string{"byod.example.com"}
|
||||
s := &mockStore{
|
||||
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
@@ -51,17 +51,17 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -93,6 +93,21 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddressesForAccount mocks base method.
|
||||
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusterAddressesForAccount indicates an expected call of GetActiveClusterAddressesForAccount.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -107,6 +122,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// GetAccountProxy mocks base method.
|
||||
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountProxy indicates an expected call of GetAccountProxy.
|
||||
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
|
||||
}
|
||||
|
||||
// CountAccountProxies mocks base method.
|
||||
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAccountProxies indicates an expected call of CountAccountProxies.
|
||||
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable mocks base method.
|
||||
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
|
||||
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteProxy mocks base method.
|
||||
func (m *MockManager) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteProxy indicates an expected call of DeleteProxy.
|
||||
func (mr *MockManagerMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockManager)(nil).DeleteProxy), ctx, proxyID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@@ -2,11 +2,17 @@ package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
StatusConnected = "connected"
|
||||
StatusDisconnected = "disconnected"
|
||||
)
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
AccountID *string `gorm:"type:varchar(255);uniqueIndex:idx_proxy_account_id_unique"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
|
||||
184
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
184
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{store: s, permissionsManager: permissionsManager}
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.ProxyTokenRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" || len(req.Name) > 255 {
|
||||
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
var expiresIn time.Duration
|
||||
if req.ExpiresIn != nil && *req.ExpiresIn > 0 {
|
||||
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
|
||||
}
|
||||
|
||||
accountID := userAuth.AccountId
|
||||
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toProxyTokenCreatedResponse(generated)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.ProxyToken, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
resp = append(resp, toProxyTokenResponse(token))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := mux.Vars(r)["tokenId"]
|
||||
if tokenID == "" {
|
||||
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||
resp := api.ProxyToken{
|
||||
Id: token.ID,
|
||||
Name: token.Name,
|
||||
Revoked: token.Revoked,
|
||||
}
|
||||
if !token.CreatedAt.IsZero() {
|
||||
resp.CreatedAt = token.CreatedAt
|
||||
}
|
||||
if token.ExpiresAt != nil {
|
||||
resp.ExpiresAt = token.ExpiresAt
|
||||
}
|
||||
if token.LastUsed != nil {
|
||||
resp.LastUsed = token.LastUsed
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
|
||||
base := toProxyTokenResponse(&generated.ProxyAccessToken)
|
||||
plainToken := string(generated.PlainToken)
|
||||
return api.ProxyTokenCreated{
|
||||
Id: base.Id,
|
||||
Name: base.Name,
|
||||
CreatedAt: base.CreatedAt,
|
||||
ExpiresAt: base.ExpiresAt,
|
||||
LastUsed: base.LastUsed,
|
||||
Revoked: base.Revoked,
|
||||
PlainToken: plainToken,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func authContext(accountID, userID string) context.Context {
|
||||
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateToken_AccountScoped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "my-token"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp api.ProxyTokenCreated
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
assert.NotEmpty(t, resp.PlainToken)
|
||||
assert.Equal(t, "my-token", resp.Name)
|
||||
assert.False(t, resp.Revoked)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.AccountID)
|
||||
assert.Equal(t, accountID, *savedToken.AccountID)
|
||||
assert.Equal(t, "user-1", savedToken.CreatedBy)
|
||||
}
|
||||
|
||||
func TestCreateToken_WithExpiration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "expiring-token", "expires_in": 3600}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.ExpiresAt)
|
||||
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestCreateToken_EmptyName(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": ""}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "test"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestListTokens(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
now := time.Now()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
|
||||
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
|
||||
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listTokens(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.ProxyToken
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Len(t, resp, 2)
|
||||
assert.Equal(t, "tok-1", resp[0].Id)
|
||||
assert.False(t, resp[0].Revoked)
|
||||
assert.Equal(t, "tok-2", resp[1].Id)
|
||||
assert.True(t, resp[1].Revoked)
|
||||
}
|
||||
|
||||
func TestRevokeToken_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
Name: "test-token",
|
||||
AccountID: &accountID,
|
||||
}, nil)
|
||||
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
otherAccount := "acc-other"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: &otherAccount,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: nil,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package selfhostedproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// ProxyDisconnector can force-disconnect a connected proxy's gRPC stream.
|
||||
type ProxyDisconnector interface {
|
||||
ForceDisconnect(proxyID string)
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
proxyMgr proxy.Manager
|
||||
serviceMgr rpservice.Manager
|
||||
permissionsManager permissions.Manager
|
||||
disconnector ProxyDisconnector
|
||||
}
|
||||
|
||||
func RegisterEndpoints(proxyMgr proxy.Manager, serviceMgr rpservice.Manager, permissionsManager permissions.Manager, disconnector ProxyDisconnector, router *mux.Router) {
|
||||
h := &handler{
|
||||
proxyMgr: proxyMgr,
|
||||
serviceMgr: serviceMgr,
|
||||
permissionsManager: permissionsManager,
|
||||
disconnector: disconnector,
|
||||
}
|
||||
router.HandleFunc("/reverse-proxies/self-hosted-proxies", h.listProxies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/self-hosted-proxies/{proxyId}", h.deleteProxy).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) listProxies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.proxyMgr.GetAccountProxy(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
if isNotFound(err) {
|
||||
util.WriteJSONObject(r.Context(), w, []api.SelfHostedProxy{})
|
||||
return
|
||||
}
|
||||
util.WriteErrorResponse("failed to get proxy", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceCount := 0
|
||||
services, err := h.serviceMgr.GetAccountServices(r.Context(), userAuth.AccountId)
|
||||
if err == nil {
|
||||
for _, svc := range services {
|
||||
if svc.ProxyCluster == p.ClusterAddress {
|
||||
serviceCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp := []api.SelfHostedProxy{toSelfHostedProxyResponse(p, serviceCount)}
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) deleteProxy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
proxyID := mux.Vars(r)["proxyId"]
|
||||
if proxyID == "" {
|
||||
util.WriteErrorResponse("proxy ID is required", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.proxyMgr.GetAccountProxy(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("proxy not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if p.ID != proxyID {
|
||||
util.WriteErrorResponse("proxy not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if h.disconnector != nil {
|
||||
h.disconnector.ForceDisconnect(proxyID)
|
||||
}
|
||||
|
||||
if err := h.proxyMgr.DeleteProxy(r.Context(), proxyID); err != nil {
|
||||
util.WriteErrorResponse("failed to delete proxy", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func isNotFound(err error) bool {
|
||||
e, ok := status.FromError(err)
|
||||
return ok && e.Type() == status.NotFound
|
||||
}
|
||||
|
||||
func toSelfHostedProxyResponse(p *proxy.Proxy, serviceCount int) api.SelfHostedProxy {
|
||||
st := api.SelfHostedProxyStatus(p.Status)
|
||||
resp := api.SelfHostedProxy{
|
||||
Id: p.ID,
|
||||
ClusterAddress: p.ClusterAddress,
|
||||
Status: st,
|
||||
LastSeen: p.LastSeen,
|
||||
ServiceCount: serviceCount,
|
||||
}
|
||||
if p.IPAddress != "" {
|
||||
resp.IpAddress = &p.IPAddress
|
||||
}
|
||||
if p.ConnectedAt != nil {
|
||||
resp.ConnectedAt = p.ConnectedAt
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package selfhostedproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type mockDisconnector struct {
|
||||
disconnectedIDs []string
|
||||
}
|
||||
|
||||
func (m *mockDisconnector) ForceDisconnect(proxyID string) {
|
||||
m.disconnectedIDs = append(m.disconnectedIDs, proxyID)
|
||||
}
|
||||
|
||||
func authContext(accountID, userID string) context.Context {
|
||||
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func TestListProxies_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
now := time.Now()
|
||||
connAt := now.Add(-1 * time.Hour)
|
||||
|
||||
proxyMgr := proxy.NewMockManager(ctrl)
|
||||
proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), accountID).Return(&proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byod.example.com",
|
||||
IPAddress: "10.0.0.1",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &connAt,
|
||||
}, nil)
|
||||
|
||||
serviceMgr := rpservice.NewMockManager(ctrl)
|
||||
serviceMgr.EXPECT().GetAccountServices(gomock.Any(), accountID).Return([]*rpservice.Service{
|
||||
{ProxyCluster: "byod.example.com"},
|
||||
{ProxyCluster: "byod.example.com"},
|
||||
{ProxyCluster: "other.cluster.com"},
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
proxyMgr: proxyMgr,
|
||||
serviceMgr: serviceMgr,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/self-hosted-proxies", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listProxies(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.SelfHostedProxy
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Len(t, resp, 1)
|
||||
assert.Equal(t, "proxy-1", resp[0].Id)
|
||||
assert.Equal(t, "byod.example.com", resp[0].ClusterAddress)
|
||||
assert.Equal(t, 2, resp[0].ServiceCount)
|
||||
assert.Equal(t, api.SelfHostedProxyStatus(proxy.StatusConnected), resp[0].Status)
|
||||
}
|
||||
|
||||
func TestListProxies_NoProxy(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
proxyMgr := proxy.NewMockManager(ctrl)
|
||||
proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), "acc-123").Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
proxyMgr: proxyMgr,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/self-hosted-proxies", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listProxies(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.SelfHostedProxy
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
func TestListProxies_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Read).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/self-hosted-proxies", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listProxies(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestDeleteProxy_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
disconnector := &mockDisconnector{}
|
||||
|
||||
proxyMgr := proxy.NewMockManager(ctrl)
|
||||
proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), accountID).Return(&proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}, nil)
|
||||
proxyMgr.EXPECT().DeleteProxy(gomock.Any(), "proxy-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
proxyMgr: proxyMgr,
|
||||
permissionsManager: permsMgr,
|
||||
disconnector: disconnector,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/self-hosted-proxies/proxy-1", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"proxyId": "proxy-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.deleteProxy(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, disconnector.disconnectedIDs, "proxy-1")
|
||||
}
|
||||
|
||||
func TestDeleteProxy_WrongProxyID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
|
||||
proxyMgr := proxy.NewMockManager(ctrl)
|
||||
proxyMgr.EXPECT().GetAccountProxy(gomock.Any(), accountID).Return(&proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
AccountID: &accountID,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
proxyMgr: proxyMgr,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/self-hosted-proxies/proxy-other", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"proxyId": "proxy-other"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.deleteProxy(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestDeleteProxy_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/self-hosted-proxies/proxy-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"proxyId": "proxy-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.deleteProxy(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
@@ -25,4 +25,5 @@ type Manager interface {
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
|
||||
}
|
||||
|
||||
@@ -122,6 +122,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -627,6 +627,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
|
||||
@@ -425,7 +425,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -706,7 +706,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user