mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
Use a 1:1 mapping of netbird client to netbird account
- Add debug endpoint for monitoring netbird clients - Add types package with AccountID type - Refactor netbird roundtrip to key clients by AccountID - Multiple domains can share the same client per account - Add status notifier for tunnel connection updates - Add OIDC flags to CLI - Add tests for netbird client management
This commit is contained in:
@@ -4,18 +4,39 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const deviceNamePrefix = "ingress-"
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
var ErrNoAccountID = errors.New("no account ID in request context")
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
type domainInfo struct {
|
||||
reverseProxyID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
|
||||
@@ -23,147 +44,389 @@ type statusNotifier interface {
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
logger *log.Logger
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[string]*embed.Client
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
}
|
||||
|
||||
func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
// NewNetBird creates a new NetBird transport.
|
||||
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &NetBird{
|
||||
mgmtAddr: mgmtAddr,
|
||||
proxyID: proxyID,
|
||||
logger: logger,
|
||||
clients: make(map[string]*embed.Client),
|
||||
clients: make(map[types.AccountID]*clientEntry),
|
||||
statusNotifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error {
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// one is created using the provided setup key. Multiple domains can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
// Client already exists for this account, just register the domain
|
||||
entry.domains[d] = domainInfo{reverseProxyID: reverseProxyID}
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("registered domain with existing client")
|
||||
|
||||
// If client is already started, notify this domain as connected immediately
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), reverseProxyID, string(d), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
n.initLogOnce.Do(func() {
|
||||
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
|
||||
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
wgPort := 0
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + domain,
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.mgmtAddr,
|
||||
SetupKey: key,
|
||||
LogOutput: io.Discard,
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
entry = &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {reverseProxyID: reverseProxyID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background, if this fails
|
||||
// then it is not ideal, but it isn't the end of the world because
|
||||
// we will try to start the client again before we use it.
|
||||
go func() {
|
||||
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
err = client.Start(startCtx)
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
n.logger.Debug("netbird client timed out")
|
||||
// This is not ideal, but we will try again later.
|
||||
return
|
||||
case err != nil:
|
||||
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Debug("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Notify management that tunnel is now active
|
||||
// Mark client as started and notify all registered domains
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
// Copy domain info while holding lock
|
||||
var domainsToNotify []struct {
|
||||
domain domain.Domain
|
||||
reverseProxyID string
|
||||
}
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, struct {
|
||||
domain domain.Domain
|
||||
reverseProxyID string
|
||||
}{domain: dom, reverseProxyID: info.reverseProxyID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify all domains that they're connected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection")
|
||||
for _, domInfo := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(domInfo.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
n.clients[domain] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error {
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[domain]
|
||||
n.clientsMux.RUnlock()
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
// when no domains are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
// Mission failed successfully!
|
||||
n.clientsMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("stop netbird client: %w", err)
|
||||
|
||||
// Get domain info before deleting
|
||||
domInfo, domainExists := entry.domains[d]
|
||||
if !domainExists {
|
||||
n.clientsMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify management that tunnel is disconnected
|
||||
delete(entry.domains, d)
|
||||
|
||||
// If there are still domains using this client, keep it running
|
||||
if len(entry.domains) > 0 {
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"remaining_domains": len(entry.domains),
|
||||
}).Debug("unregistered domain, client still in use")
|
||||
|
||||
// Notify this domain as disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// No more domains using this client, stop it
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Info("stopping client, no more domains")
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify disconnection before stopping
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection")
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
delete(n.clients, domain)
|
||||
transport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||
// specified in the request context and uses it to dial the backend.
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
host = req.Host
|
||||
accountID := AccountIDFromContext(req.Context())
|
||||
if accountID == "" {
|
||||
return nil, ErrNoAccountID
|
||||
}
|
||||
|
||||
// Copy references while holding lock, then unlock early to avoid blocking
|
||||
// other requests during the potentially slow RoundTrip.
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[host]
|
||||
// Immediately unlock after retrieval here rather than defer to avoid
|
||||
// the call to client.Do blocking other clients being used whilst one
|
||||
// is in use.
|
||||
n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no peer connection found for host: %s", host)
|
||||
n.clientsMux.RUnlock()
|
||||
return nil, fmt.Errorf("no peer connection found for account: %s", accountID)
|
||||
}
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
// Attempt to start the client, if the client is already running then
|
||||
// it will return an error that we ignore, if this hits a timeout then
|
||||
// this request is unprocessable.
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = client.Start(startCtx)
|
||||
switch {
|
||||
case errors.Is(err, embed.ErrClientAlreadyStarted):
|
||||
break
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
|
||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"account_id": accountID,
|
||||
"host": req.Host,
|
||||
"url": req.URL.String(),
|
||||
"requestURI": req.RequestURI,
|
||||
"method": req.Method,
|
||||
}).Debug("running roundtrip for peer connection")
|
||||
|
||||
// Create a new transport using the client dialer and perform the roundtrip.
|
||||
// We do this instead of using the client HTTPClient to avoid issues around
|
||||
// client request validation that do not work with the reverse proxied
|
||||
// requests.
|
||||
// Other values are simply copied from the http.DefaultTransport which the
|
||||
// standard reverse proxy implementation would have used.
|
||||
// TODO: tune this transport for our needs.
|
||||
return (&http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}).RoundTrip(req)
|
||||
return transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// StopAll stops all clients.
|
||||
func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
entry.transport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client during shutdown")
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
maps.Clear(n.clients)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// HasClient returns true if there is a client for the given account.
|
||||
func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
_, exists := n.clients[accountID]
|
||||
return exists
|
||||
}
|
||||
|
||||
// DomainCount returns the number of domains registered for the given account.
|
||||
// Returns 0 if the account has no client.
|
||||
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
return len(entry.domains)
|
||||
}
|
||||
|
||||
// ClientCount returns the total number of active clients.
|
||||
func (n *NetBird) ClientCount() int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
return len(n.clients)
|
||||
}
|
||||
|
||||
// GetClient returns the embed.Client for the given account ID.
|
||||
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
return entry.client, true
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
type ClientDebugInfo struct {
|
||||
AccountID types.AccountID
|
||||
DomainCount int
|
||||
Domains domain.List
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
|
||||
result := make(map[types.AccountID]ClientDebugInfo)
|
||||
for accountID, entry := range n.clients {
|
||||
domains := make(domain.List, 0, len(entry.domains))
|
||||
for d := range entry.domains {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
result[accountID] = ClientDebugInfo{
|
||||
AccountID: accountID,
|
||||
DomainCount: len(entry.domains),
|
||||
Domains: domains,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
type accountIDContextKey struct{}
|
||||
|
||||
// WithAccountID adds the account ID to the context.
|
||||
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||
}
|
||||
|
||||
// AccountIDFromContext retrieves the account ID from the context.
|
||||
func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIDContextKey{})
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
accountID, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountID
|
||||
}
|
||||
|
||||
247
proxy/internal/roundtrip/netbird_test.go
Normal file
247
proxy/internal/roundtrip/netbird_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird("http://invalid.test:9999", "test-proxy", nil, nil)
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Initially no client exists.
|
||||
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
|
||||
// Add first domain - this should create a new client.
|
||||
// Note: This will fail to actually connect since we use an invalid URL,
|
||||
// but the client entry should still be created.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID))
|
||||
|
||||
// Add second domain for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
|
||||
|
||||
// Add third domain.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
|
||||
|
||||
// Still only one client.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
|
||||
// Add domain for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add domain for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both accounts should have their own clients.
|
||||
assert.True(t, nb.HasClient(account1), "account1 should have client")
|
||||
assert.True(t, nb.HasClient(account2), "account2 should have client")
|
||||
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add multiple domains.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID))
|
||||
|
||||
// Remove one domain - client should remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
|
||||
|
||||
// Remove another domain - client should still remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add single domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
// Remove the only domain - client should be removed.
|
||||
// Note: Stop() may fail since the client never actually connected,
|
||||
// but the entry should still be removed from the map.
|
||||
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
|
||||
// After removing all domains, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("nonexistent-account")
|
||||
|
||||
// Removing from non-existent account should not error.
|
||||
err := nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
assert.NoError(t, err, "removing from non-existent account should not error")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add one domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove non-existent domain - should not affect existing domain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original domain should still be registered.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
|
||||
}
|
||||
|
||||
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := types.AccountID("test-account")
|
||||
|
||||
// Initially no account ID in context.
|
||||
retrieved := AccountIDFromContext(ctx)
|
||||
assert.True(t, retrieved == "", "should be empty when not set")
|
||||
|
||||
// Add account ID to context.
|
||||
ctx = WithAccountID(ctx, accountID)
|
||||
retrieved = AccountIDFromContext(ctx)
|
||||
assert.Equal(t, accountID, retrieved, "should retrieve the same account ID")
|
||||
}
|
||||
|
||||
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
|
||||
// Create context with wrong type for account ID key.
|
||||
ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string")
|
||||
|
||||
retrieved := AccountIDFromContext(ctx)
|
||||
assert.True(t, retrieved == "", "should return empty for wrong type")
|
||||
}
|
||||
|
||||
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
account3 := types.AccountID("account-3")
|
||||
|
||||
// Add domains for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
|
||||
|
||||
// Stop all clients.
|
||||
// Note: StopAll may return errors since clients never actually connected,
|
||||
// but the clients should still be removed from the map.
|
||||
_ = nb.StopAll(context.Background())
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
|
||||
assert.False(t, nb.HasClient(account1), "account1 should not have client")
|
||||
assert.False(t, nb.HasClient(account2), "account2 should not have client")
|
||||
assert.False(t, nb.HasClient(account3), "account3 should not have client")
|
||||
}
|
||||
|
||||
func TestNetBird_ClientCount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
|
||||
|
||||
// Add clients for different accounts.
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.ClientCount())
|
||||
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount())
|
||||
|
||||
// Adding domain to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
|
||||
// Create a request without account ID in context.
|
||||
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RoundTrip should fail because no account ID in context.
|
||||
_, err = nb.RoundTrip(req)
|
||||
require.ErrorIs(t, err, ErrNoAccountID)
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("nonexistent-account")
|
||||
|
||||
// Create a request with account ID but no client exists.
|
||||
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||
require.NoError(t, err)
|
||||
req = req.WithContext(WithAccountID(req.Context(), accountID))
|
||||
|
||||
// RoundTrip should fail because no client for this account.
|
||||
_, err = nb.RoundTrip(req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||
}
|
||||
Reference in New Issue
Block a user