[proxy] feature: bring your own proxy

This commit is contained in:
crn4
2026-03-17 13:17:50 +01:00
parent 5585adce18
commit 26ba03f08e
32 changed files with 2697 additions and 94 deletions

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
}
type Manager struct {
@@ -68,8 +69,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// For BYOD accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
@@ -112,8 +113,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// Verify the target cluster is in the available clusters for this account
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -259,7 +260,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -284,6 +285,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byodAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOD cluster addresses: %w", err)
}
if len(byodAddresses) > 0 {
return byodAddresses, nil
}
return m.proxyManager.GetActiveClusterAddresses(ctx)
}
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
for _, customDomain := range customDomains {
if strings.HasSuffix(domain, "."+customDomain.Domain) {

View File

@@ -0,0 +1,144 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func TestGetClusterAllowList_BYODProxy(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byod.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOD addresses exist")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byod.example.com"}, result)
}
func TestGetClusterAllowList_NoBYOD_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYODError_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYODEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}
func TestExtractClusterFromFreeDomain(t *testing.T) {
clusters := []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}
tests := []struct {
name string
domain string
wantCluster string
wantOK bool
}{
{
name: "matches EU cluster",
domain: "myapp.abc123.eu.proxy.netbird.io",
wantCluster: "eu.proxy.netbird.io",
wantOK: true,
},
{
name: "matches US cluster",
domain: "myapp.xyz789.us.proxy.netbird.io",
wantCluster: "us.proxy.netbird.io",
wantOK: true,
},
{
name: "no match - custom domain",
domain: "app.example.com",
wantOK: false,
},
{
name: "no match - partial cluster name",
domain: "proxy.netbird.io",
wantOK: false,
},
{
name: "exact cluster name - no prefix",
domain: "eu.proxy.netbird.io",
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cluster, ok := ExtractClusterFromFreeDomain(tt.domain, clusters)
assert.Equal(t, tt.wantOK, ok)
if tt.wantOK {
assert.Equal(t, tt.wantCluster, cluster)
}
})
}
}