mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-14 12:49:57 +00:00
[proxy] feature: bring your own proxy (#5627)
This commit is contained in:
@@ -16,11 +16,16 @@ type store interface {
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
@@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
AccountID: accountID,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
Status: proxy.StatusConnected,
|
||||
Capabilities: caps,
|
||||
}
|
||||
|
||||
@@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
@@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
@@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
@@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
clusters, err := m.store.GetActiveProxyClusters(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
@@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !conflicting, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,337 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.saveProxyFunc != nil {
|
||||
return m.saveProxyFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
if m.disconnectProxyFunc != nil {
|
||||
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.updateProxyHeartbeatFunc != nil {
|
||||
return m.updateProxyHeartbeatFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
if m.cleanupStaleProxiesFunc != nil {
|
||||
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
if m.getProxyByAccountIDFunc != nil {
|
||||
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
if m.countProxiesByAccountIDFunc != nil {
|
||||
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
if m.isClusterAddressConflictingFunc != nil {
|
||||
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||
if m.deleteAccountClusterFunc != nil {
|
||||
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
m, err := NewManager(s, meter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestConnect_WithAccountID(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||
assert.Equal(t, "session-1", savedProxy.SessionID)
|
||||
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||
}
|
||||
|
||||
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Nil(t, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
}
|
||||
|
||||
func TestConnect_StoreError(t *testing.T) {
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conflicting bool
|
||||
storeErr error
|
||||
wantResult bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available - no conflict",
|
||||
conflicting: false,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "not available - conflict exists",
|
||||
conflicting: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||
return tt.conflicting, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountAccountProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
storeErr error
|
||||
wantCount int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no proxies",
|
||||
count: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "one proxy",
|
||||
count: 1,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||
return tt.count, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCount, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxy(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
expected := &proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byop.example.com",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||
assert.Equal(t, accountID, accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, p)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteAccountCluster(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var deletedCluster, deletedAccount string
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
|
||||
deletedCluster = clusterAddress
|
||||
deletedAccount = accountID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "cluster.example.com", deletedCluster)
|
||||
assert.Equal(t, "acc-123", deletedAccount)
|
||||
})
|
||||
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||
expected := []string{"byop.example.com"}
|
||||
s := &mockStore{
|
||||
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
Reference in New Issue
Block a user