[management, reverse proxy] Add reverse proxy feature (#5291)

* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
Pascal Fischer
2026-02-13 19:37:43 +01:00
committed by GitHub
parent edce11b34d
commit f53155562f
225 changed files with 35513 additions and 235 deletions

View File

@@ -0,0 +1,575 @@
package roundtrip
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/embed"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const deviceNamePrefix = "ingress-proxy-"
// backendKey identifies a backend by its host:port from the target URL.
type backendKey = string
var (
// ErrNoAccountID is returned when a request context is missing the account ID.
ErrNoAccountID = errors.New("no account ID in request context")
// ErrNoPeerConnection is returned when no embedded client exists for the account.
ErrNoPeerConnection = errors.New("no peer connection found")
// ErrClientStartFailed is returned when the embedded client fails to start.
ErrClientStartFailed = errors.New("client start failed")
// ErrTooManyInflight is returned when the per-backend in-flight limit is reached.
ErrTooManyInflight = errors.New("too many in-flight requests")
)
// domainInfo holds metadata about a registered domain.
type domainInfo struct {
serviceID string
}
type domainNotification struct {
domain domain.Domain
serviceID string
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
type clientEntry struct {
client *embed.Client
transport *http.Transport
domains map[domain.Domain]domainInfo
createdAt time.Time
started bool
// Per-backend in-flight limiting keyed by target host:port.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
inflightMap map[backendKey]chan struct{}
maxInflight int
}
// acquireInflight attempts to acquire an in-flight slot for the given backend.
// It returns a release function that must always be called, and true on success.
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
noop := func() {}
if e.maxInflight <= 0 {
return noop, true
}
e.inflightMu.Lock()
sem, exists := e.inflightMap[backend]
if !exists {
sem = make(chan struct{}, e.maxInflight)
e.inflightMap[backend] = sem
}
e.inflightMu.Unlock()
select {
case sem <- struct{}{}:
return func() { <-sem }, true
default:
return noop, false
}
}
type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
}
type managementClient interface {
CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest, opts ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error)
}
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
type NetBird struct {
mgmtAddr string
proxyID string
proxyAddr string
wgPort int
logger *log.Logger
mgmtClient managementClient
transportCfg transportConfig
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
initLogOnce sync.Once
statusNotifier statusNotifier
}
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
}
// accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
// Client already exists for this account, just register the domain
entry.domains[d] = domainInfo{serviceID: serviceID}
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("registered domain with existing client")
// If client is already started, notify this domain as connected immediately
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, entry.client)
return nil
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
}).Debug("generating WireGuard keypair for new peer")
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("generate wireguard private key: %w", err)
}
publicKey := privateKey.PublicKey()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
AccountId: string(accountID),
Token: authToken,
WireguardPublicKey: publicKey.String(),
Cluster: n.proxyAddr,
})
if err != nil {
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
}
if resp != nil && !resp.GetSuccess() {
errMsg := "unknown error"
if resp.ErrorMessage != nil {
errMsg = *resp.ErrorMessage
}
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
}
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
"public_key": publicKey.String(),
}).Info("proxy peer authenticated successfully with management")
n.initLogOnce.Do(func() {
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
}
})
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.mgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &n.wgPort,
})
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
}
// Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
DialContext: client.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
IdleConnTimeout: n.transportCfg.idleConnTimeout,
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
WriteBufferSize: n.transportCfg.writeBufferSize,
ReadBufferSize: n.transportCfg.readBufferSize,
DisableCompression: n.transportCfg.disableCompression,
},
createdAt: time.Now(),
started: false,
inflightMap: make(map[backendKey]chan struct{}),
maxInflight: n.transportCfg.maxInflight,
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
} else {
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
}
return
}
// Mark client as started and collect domains to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
}).Info("notified management about tunnel connection")
}
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped
// when no domains are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if !exists {
n.clientsMux.Unlock()
n.logger.WithField("account_id", accountID).Debug("remove peer: account not found")
return nil
}
// Get domain info before deleting
domInfo, domainExists := entry.domains[d]
if !domainExists {
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("remove peer: domain not registered")
return nil
}
delete(entry.domains, d)
// If there are still domains using this client, keep it running
if len(entry.domains) > 0 {
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"remaining_domains": len(entry.domains),
}).Debug("unregistered domain, client still in use")
// Notify this domain as disconnected
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
return nil
}
// No more domains using this client, stop it
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Info("stopping client, no more domains")
client := entry.client
transport := entry.transport
delete(n.clients, accountID)
n.clientsMux.Unlock()
// Notify disconnection before stopping
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
transport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client")
}
return nil
}
// RoundTrip implements http.RoundTripper. It looks up the client for the account
// specified in the request context and uses it to dial the backend.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
accountID := AccountIDFromContext(req.Context())
if accountID == "" {
return nil, ErrNoAccountID
}
// Copy references while holding lock, then unlock early to avoid blocking
// other requests during the potentially slow RoundTrip.
n.clientsMux.RLock()
entry, exists := n.clients[accountID]
if !exists {
n.clientsMux.RUnlock()
return nil, fmt.Errorf("%w for account: %s", ErrNoPeerConnection, accountID)
}
client := entry.client
transport := entry.transport
n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host)
defer release()
if !ok {
return nil, ErrTooManyInflight
}
// Attempt to start the client, if the client is already running then
// it will return an error that we ignore, if this hits a timeout then
// this request is unprocessable.
startCtx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
if err := client.Start(startCtx); err != nil {
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
return nil, fmt.Errorf("%w: %w", ErrClientStartFailed, err)
}
}
start := time.Now()
resp, err := transport.RoundTrip(req)
duration := time.Since(start)
if err != nil {
n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s duration=%s err=%v",
req.Method, req.Host, req.URL.String(), accountID, duration.Truncate(time.Millisecond), err)
return nil, err
}
n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s status=%d duration=%s",
req.Method, req.Host, req.URL.String(), accountID, resp.StatusCode, duration.Truncate(time.Millisecond))
return resp, nil
}
// StopAll stops all clients.
func (n *NetBird) StopAll(ctx context.Context) error {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
var merr *multierror.Error
for accountID, entry := range n.clients {
entry.transport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client during shutdown")
merr = multierror.Append(merr, err)
}
}
maps.Clear(n.clients)
return nberrors.FormatErrorOrNil(merr)
}
// HasClient returns true if there is a client for the given account.
func (n *NetBird) HasClient(accountID types.AccountID) bool {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
_, exists := n.clients[accountID]
return exists
}
// DomainCount returns the number of domains registered for the given account.
// Returns 0 if the account has no client.
func (n *NetBird) DomainCount(accountID types.AccountID) int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return 0
}
return len(entry.domains)
}
// ClientCount returns the total number of active clients.
func (n *NetBird) ClientCount() int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
return len(n.clients)
}
// GetClient returns the embed.Client for the given account ID.
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return nil, false
}
return entry.client, true
}
// ListClientsForDebug returns information about all clients for debug purposes.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
result := make(map[types.AccountID]ClientDebugInfo)
for accountID, entry := range n.clients {
domains := make(domain.List, 0, len(entry.domains))
for d := range entry.domains {
domains = append(domains, d)
}
result[accountID] = ClientDebugInfo{
AccountID: accountID,
DomainCount: len(entry.domains),
Domains: domains,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
}
}
return result
}
// ListClientsForStartup returns all embed.Client instances for health checks.
func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
result := make(map[types.AccountID]*embed.Client)
for accountID, entry := range n.clients {
if entry.client != nil {
result[accountID] = entry.client
}
}
return result
}
// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random
// OS-assigned port. A fixed port only works with single-account deployments;
// multiple accounts will fail to bind the same port.
func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
mgmtAddr: mgmtAddr,
proxyID: proxyID,
proxyAddr: proxyAddr,
wgPort: wgPort,
logger: logger,
clients: make(map[types.AccountID]*clientEntry),
statusNotifier: notifier,
mgmtClient: mgmtClient,
transportCfg: loadTransportConfig(logger),
}
}
// WithAccountID adds the account ID to the context.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
return context.WithValue(ctx, accountIDContextKey{}, accountID)
}
// AccountIDFromContext retrieves the account ID from the context.
func AccountIDFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIDContextKey{})
if v == nil {
return ""
}
accountID, ok := v.(types.AccountID)
if !ok {
return ""
}
return accountID
}

View File

@@ -0,0 +1,107 @@
package roundtrip
import (
"crypto/rand"
"math/big"
"sync"
"testing"
"time"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
// Simple benchmark for comparison with AddPeer contention.
func BenchmarkHasClient(b *testing.B) {
// Knobs for dialling in:
initialClientCount := 100 // Size of initial peer map to generate.
nb := mockNetBird()
var target types.AccountID
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(initialClientCount)))
if err != nil {
b.Fatal(err)
}
for i := range initialClientCount {
id := types.AccountID(rand.Text())
if int64(i) == targetIndex.Int64() {
target = id
}
nb.clients[id] = &clientEntry{
domains: map[domain.Domain]domainInfo{
domain.Domain(rand.Text()): {
serviceID: rand.Text(),
},
},
createdAt: time.Now(),
started: true,
}
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
nb.HasClient(target)
}
})
b.StopTimer()
}
func BenchmarkHasClientDuringAddPeer(b *testing.B) {
// Knobs for dialling in:
initialClientCount := 100 // Size of initial peer map to generate.
addPeerWorkers := 5 // Number of workers to concurrently call AddPeer.
nb := mockNetBird()
// Add random client entries to the netbird instance.
// We're trying to test map lock contention, so starting with
// a populated map should help with this.
// Pick a random one to target for retrieval later.
var target types.AccountID
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(initialClientCount)))
if err != nil {
b.Fatal(err)
}
for i := range initialClientCount {
id := types.AccountID(rand.Text())
if int64(i) == targetIndex.Int64() {
target = id
}
nb.clients[id] = &clientEntry{
domains: map[domain.Domain]domainInfo{
domain.Domain(rand.Text()): {
serviceID: rand.Text(),
},
},
createdAt: time.Now(),
started: true,
}
}
// Launch workers that continuously call AddPeer with new random accountIDs.
var wg sync.WaitGroup
for range addPeerWorkers {
wg.Go(func() {
for {
if err := nb.AddPeer(b.Context(),
types.AccountID(rand.Text()),
domain.Domain(rand.Text()),
rand.Text(),
rand.Text()); err != nil {
b.Log(err)
}
}
})
}
// Benchmark calling HasClient during AddPeer contention.
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
nb.HasClient(target)
}
})
b.StopTimer()
}

View File

@@ -0,0 +1,328 @@
package roundtrip
import (
"context"
"net/http"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockMgmtClient struct{}
func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
type mockStatusNotifier struct {
mu sync.Mutex
statuses []statusCall
}
type statusCall struct {
accountID string
serviceID string
domain string
connected bool
}
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
return nil
}
func (m *mockStatusNotifier) calls() []statusCall {
m.mu.Lock()
defer m.mu.Unlock()
return append([]statusCall{}, m.statuses...)
}
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{})
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Initially no client exists.
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
// Add first domain - this should create a new client.
// Note: This will fail to actually connect since we use an invalid URL,
// but the client entry should still be created.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
}
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add first domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
assert.Equal(t, 1, nb.DomainCount(accountID))
// Add second domain for the same account - should reuse existing client.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
require.NoError(t, err)
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
// Add third domain.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
// Still only one client.
assert.True(t, nb.HasClient(accountID))
}
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
nb := mockNetBird()
account1 := types.AccountID("account-1")
account2 := types.AccountID("account-2")
// Add domain for account 1.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
// Add domain for account 2.
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
require.NoError(t, err)
// Both accounts should have their own clients.
assert.True(t, nb.HasClient(account1), "account1 should have client")
assert.True(t, nb.HasClient(account2), "account2 should have client")
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
}
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add multiple domains.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID))
// Remove one domain - client should remain.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
// Remove another domain - client should still remain.
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
}
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add single domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
// Remove the only domain - client should be removed.
// Note: Stop() may fail since the client never actually connected,
// but the entry should still be removed from the map.
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
// After removing all domains, client should be gone.
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
}
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("nonexistent-account")
// Removing from non-existent account should not error.
err := nb.RemovePeer(context.Background(), accountID, "domain1.test")
assert.NoError(t, err, "removing from non-existent account should not error")
}
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add one domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
require.NoError(t, err)
// Remove non-existent domain - should not affect existing domain.
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
require.NoError(t, err)
// Original domain should still be registered.
assert.True(t, nb.HasClient(accountID))
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
}
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
ctx := context.Background()
accountID := types.AccountID("test-account")
// Initially no account ID in context.
retrieved := AccountIDFromContext(ctx)
assert.True(t, retrieved == "", "should be empty when not set")
// Add account ID to context.
ctx = WithAccountID(ctx, accountID)
retrieved = AccountIDFromContext(ctx)
assert.Equal(t, accountID, retrieved, "should retrieve the same account ID")
}
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
// Create context with wrong type for account ID key.
ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string")
retrieved := AccountIDFromContext(ctx)
assert.True(t, retrieved == "", "should return empty for wrong type")
}
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
nb := mockNetBird()
account1 := types.AccountID("account-1")
account2 := types.AccountID("account-2")
account3 := types.AccountID("account-3")
// Add domains for multiple accounts.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
require.NoError(t, err)
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
// Stop all clients.
// Note: StopAll may return errors since clients never actually connected,
// but the clients should still be removed from the map.
_ = nb.StopAll(context.Background())
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
assert.False(t, nb.HasClient(account1), "account1 should not have client")
assert.False(t, nb.HasClient(account2), "account2 should not have client")
assert.False(t, nb.HasClient(account3), "account3 should not have client")
}
func TestNetBird_ClientCount(t *testing.T) {
nb := mockNetBird()
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
// Add clients for different accounts.
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
require.NoError(t, err)
assert.Equal(t, 1, nb.ClientCount())
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount())
// Adding domain to existing account should not increase count.
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
}
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
nb := mockNetBird()
// Create a request without account ID in context.
req, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
// RoundTrip should fail because no account ID in context.
_, err = nb.RoundTrip(req) //nolint:bodyclose
require.ErrorIs(t, err, ErrNoAccountID)
}
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("nonexistent-account")
// Create a request with account ID but no client exists.
req, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
req = req.WithContext(WithAccountID(req.Context(), accountID))
// RoundTrip should fail because no client for this account.
_, err = nb.RoundTrip(req) //nolint:bodyclose // Error case, no response body
assert.Error(t, err)
assert.Contains(t, err.Error(), "no peer connection found for account")
}
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
// Manually mark client as started to simulate background startup completing.
nb.clientsMux.Lock()
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second domain — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, string(accountID), calls[0].accountID)
assert.Equal(t, "svc-2", calls[0].serviceID)
assert.Equal(t, "domain2.test", calls[0].domain)
assert.True(t, calls[0].connected)
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
// Remove one domain — client stays, but disconnection notification fires.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, "domain1.test", calls[0].domain)
assert.False(t, calls[0].connected)
}

View File

@@ -0,0 +1,152 @@
package roundtrip
import (
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
// Environment variable names for tuning the backend HTTP transport.
const (
EnvMaxIdleConns = "NB_PROXY_MAX_IDLE_CONNS"
EnvMaxIdleConnsPerHost = "NB_PROXY_MAX_IDLE_CONNS_PER_HOST"
EnvMaxConnsPerHost = "NB_PROXY_MAX_CONNS_PER_HOST"
EnvIdleConnTimeout = "NB_PROXY_IDLE_CONN_TIMEOUT"
EnvTLSHandshakeTimeout = "NB_PROXY_TLS_HANDSHAKE_TIMEOUT"
EnvExpectContinueTimeout = "NB_PROXY_EXPECT_CONTINUE_TIMEOUT"
EnvResponseHeaderTimeout = "NB_PROXY_RESPONSE_HEADER_TIMEOUT"
EnvWriteBufferSize = "NB_PROXY_WRITE_BUFFER_SIZE"
EnvReadBufferSize = "NB_PROXY_READ_BUFFER_SIZE"
EnvDisableCompression = "NB_PROXY_DISABLE_COMPRESSION"
EnvMaxInflight = "NB_PROXY_MAX_INFLIGHT"
)
// transportConfig holds tunable parameters for the per-account HTTP transport.
type transportConfig struct {
maxIdleConns int
maxIdleConnsPerHost int
maxConnsPerHost int
idleConnTimeout time.Duration
tlsHandshakeTimeout time.Duration
expectContinueTimeout time.Duration
responseHeaderTimeout time.Duration
writeBufferSize int
readBufferSize int
disableCompression bool
// maxInflight limits per-backend concurrent requests. 0 means unlimited.
maxInflight int
}
func defaultTransportConfig() transportConfig {
return transportConfig{
maxIdleConns: 100,
maxIdleConnsPerHost: 100,
maxConnsPerHost: 0, // unlimited
idleConnTimeout: 90 * time.Second,
tlsHandshakeTimeout: 10 * time.Second,
expectContinueTimeout: 1 * time.Second,
}
}
func loadTransportConfig(logger *log.Logger) transportConfig {
cfg := defaultTransportConfig()
if v, ok := envInt(EnvMaxIdleConns, logger); ok {
cfg.maxIdleConns = v
}
if v, ok := envInt(EnvMaxIdleConnsPerHost, logger); ok {
cfg.maxIdleConnsPerHost = v
}
if v, ok := envInt(EnvMaxConnsPerHost, logger); ok {
cfg.maxConnsPerHost = v
}
if v, ok := envDuration(EnvIdleConnTimeout, logger); ok {
cfg.idleConnTimeout = v
}
if v, ok := envDuration(EnvTLSHandshakeTimeout, logger); ok {
cfg.tlsHandshakeTimeout = v
}
if v, ok := envDuration(EnvExpectContinueTimeout, logger); ok {
cfg.expectContinueTimeout = v
}
if v, ok := envDuration(EnvResponseHeaderTimeout, logger); ok {
cfg.responseHeaderTimeout = v
}
if v, ok := envInt(EnvWriteBufferSize, logger); ok {
cfg.writeBufferSize = v
}
if v, ok := envInt(EnvReadBufferSize, logger); ok {
cfg.readBufferSize = v
}
if v, ok := envBool(EnvDisableCompression, logger); ok {
cfg.disableCompression = v
}
if v, ok := envInt(EnvMaxInflight, logger); ok {
cfg.maxInflight = v
}
logger.WithFields(log.Fields{
"max_idle_conns": cfg.maxIdleConns,
"max_idle_conns_per_host": cfg.maxIdleConnsPerHost,
"max_conns_per_host": cfg.maxConnsPerHost,
"idle_conn_timeout": cfg.idleConnTimeout,
"tls_handshake_timeout": cfg.tlsHandshakeTimeout,
"expect_continue_timeout": cfg.expectContinueTimeout,
"response_header_timeout": cfg.responseHeaderTimeout,
"write_buffer_size": cfg.writeBufferSize,
"read_buffer_size": cfg.readBufferSize,
"disable_compression": cfg.disableCompression,
"max_inflight": cfg.maxInflight,
}).Debug("backend transport configuration")
return cfg
}
func envInt(key string, logger *log.Logger) (int, bool) {
s := os.Getenv(key)
if s == "" {
return 0, false
}
v, err := strconv.Atoi(s)
if err != nil {
logger.Warnf("failed to parse %s=%q as int: %v", key, s, err)
return 0, false
}
if v < 0 {
logger.Warnf("ignoring negative value for %s=%d", key, v)
return 0, false
}
return v, true
}
func envDuration(key string, logger *log.Logger) (time.Duration, bool) {
s := os.Getenv(key)
if s == "" {
return 0, false
}
v, err := time.ParseDuration(s)
if err != nil {
logger.Warnf("failed to parse %s=%q as duration: %v", key, s, err)
return 0, false
}
if v < 0 {
logger.Warnf("ignoring negative value for %s=%s", key, v)
return 0, false
}
return v, true
}
func envBool(key string, logger *log.Logger) (bool, bool) {
s := os.Getenv(key)
if s == "" {
return false, false
}
v, err := strconv.ParseBool(s)
if err != nil {
logger.Warnf("failed to parse %s=%q as bool: %v", key, s, err)
return false, false
}
return v, true
}