[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -14,11 +15,12 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/embed"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -26,7 +28,22 @@ import (
const deviceNamePrefix = "ingress-proxy-"
// backendKey identifies a backend by its host:port from the target URL.
type backendKey = string
type backendKey string
// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service)
// that holds a reference to an embedded NetBird client. Callers should use the
// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions.
type ServiceKey string
// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service.
func DomainServiceKey(domain string) ServiceKey {
return ServiceKey("domain:" + domain)
}
// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP).
func L4ServiceKey(id types.ServiceID) ServiceKey {
return ServiceKey("l4:" + id)
}
var (
// ErrNoAccountID is returned when a request context is missing the account ID.
@@ -39,24 +56,24 @@ var (
ErrTooManyInflight = errors.New("too many in-flight requests")
)
// domainInfo holds metadata about a registered domain.
type domainInfo struct {
serviceID string
// serviceInfo holds metadata about a registered service.
type serviceInfo struct {
serviceID types.ServiceID
}
type domainNotification struct {
domain domain.Domain
serviceID string
type serviceNotification struct {
key ServiceKey
serviceID types.ServiceID
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
// clientEntry holds an embedded NetBird client and tracks which services use it.
type clientEntry struct {
client *embed.Client
transport *http.Transport
// insecureTransport is a clone of transport with TLS verification disabled,
// used when per-target skip_tls_verify is set.
insecureTransport *http.Transport
domains map[domain.Domain]domainInfo
services map[ServiceKey]serviceInfo
createdAt time.Time
started bool
// Per-backend in-flight limiting keyed by target host:port.
@@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
// ClientConfig holds configuration for the embedded NetBird client.
type ClientConfig struct {
MgmtAddr string
WGPort int
WGPort uint16
PreSharedKey string
}
type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error
}
type managementClient interface {
@@ -107,7 +124,7 @@ type managementClient interface {
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
type NetBird struct {
proxyID string
proxyAddr string
@@ -124,11 +141,11 @@ type NetBird struct {
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
AccountID types.AccountID
ServiceCount int
ServiceKeys []string
HasClient bool
CreatedAt time.Time
}
// accountIDContextKey is the context key for storing the account ID.
@@ -137,37 +154,37 @@ type accountIDContextKey struct{}
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
type skipTLSVerifyContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// AddPeer registers a service for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
// Multiple services can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
// Client already exists for this account, just register the domain
entry.domains[d] = domainInfo{serviceID: serviceID}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("registered domain with existing client")
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
// If client is already started, notify this domain as connected immediately
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if err != nil {
n.clientsMux.Unlock()
return err
@@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"account_id": accountID,
"service_key": key,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
@@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
@@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
WireguardPublicKey: publicKey.String(),
@@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &n.clientCfg.WGPort,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
if err != nil {
@@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
transport := &http.Transport{
DialContext: client.DialContext,
DialContext: dialWithTimeout(client.DialContext),
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
@@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
services: map[ServiceKey]serviceInfo{key: si},
transport: transport,
insecureTransport: insecureTransport,
createdAt: time.Now(),
@@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
// runClientStartup starts the client and notifies registered services on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
// Mark client as started and collect domains to notify outside the lock.
// Mark client as started and collect services to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
var toNotify []serviceNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
for key, info := range entry.services {
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
@@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
for _, sn := range toNotify {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
"account_id": accountID,
"service_key": sn.key,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
"account_id": accountID,
"service_key": sn.key,
}).Info("notified management about tunnel connection")
}
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped
// when no domains are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
return nil
}
// Get domain info before deleting
domInfo, domainExists := entry.domains[d]
if !domainExists {
si, svcExists := entry.services[key]
if !svcExists {
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("remove peer: domain not registered")
"account_id": accountID,
"service_key": key,
}).Debug("remove peer: service not registered")
return nil
}
delete(entry.domains, d)
// If there are still domains using this client, keep it running
if len(entry.domains) > 0 {
n.clientsMux.Unlock()
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"remaining_domains": len(entry.domains),
}).Debug("unregistered domain, client still in use")
// Notify this domain as disconnected
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
return nil
"account_id": accountID,
"service_key": key,
"remaining_services": len(entry.services),
}).Debug("unregistered service, client still in use")
}
// No more domains using this client, stop it
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Info("stopping client, no more domains")
client := entry.client
transport := entry.transport
insecureTransport := entry.insecureTransport
delete(n.clients, accountID)
n.clientsMux.Unlock()
// Notify disconnection before stopping
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client")
}
return nil
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return
}
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil {
if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound {
n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
}
// RoundTrip implements http.RoundTripper. It looks up the client for the account
// specified in the request context and uses it to dial the backend.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
}
n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host)
release, ok := entry.acquireInflight(backendKey(req.URL.Host))
defer release()
if !ok {
return nil, ErrTooManyInflight
@@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool {
return exists
}
// DomainCount returns the number of domains registered for the given account.
// ServiceCount returns the number of services registered for the given account.
// Returns 0 if the account has no client.
func (n *NetBird) DomainCount(accountID types.AccountID) int {
func (n *NetBird) ServiceCount(accountID types.AccountID) int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return 0
}
return len(entry.domains)
return len(entry.services)
}
// ClientCount returns the total number of active clients.
@@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
result := make(map[types.AccountID]ClientDebugInfo)
for accountID, entry := range n.clients {
domains := make(domain.List, 0, len(entry.domains))
for d := range entry.domains {
domains = append(domains, d)
keys := make([]string, 0, len(entry.services))
for k := range entry.services {
keys = append(keys, string(k))
}
result[accountID] = ClientDebugInfo{
AccountID: accountID,
DomainCount: len(entry.domains),
Domains: domains,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
AccountID: accountID,
ServiceCount: len(entry.services),
ServiceKeys: keys,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
}
}
return result
@@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
}
}
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
if d, ok := types.DialTimeoutFromContext(ctx); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, d)
defer cancel()
}
return dial(ctx, network, addr)
}
}
// WithAccountID adds the account ID to the context.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
return context.WithValue(ctx, accountIDContextKey{}, accountID)

View File

@@ -1,6 +1,7 @@
package roundtrip
import (
"context"
"crypto/rand"
"math/big"
"sync"
@@ -8,7 +9,6 @@ import (
"time"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
// Simple benchmark for comparison with AddPeer contention.
@@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) {
target = id
}
nb.clients[id] = &clientEntry{
domains: map[domain.Domain]domainInfo{
domain.Domain(rand.Text()): {
serviceID: rand.Text(),
services: map[ServiceKey]serviceInfo{
ServiceKey(rand.Text()): {
serviceID: types.ServiceID(rand.Text()),
},
},
createdAt: time.Now(),
@@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
target = id
}
nb.clients[id] = &clientEntry{
domains: map[domain.Domain]domainInfo{
domain.Domain(rand.Text()): {
serviceID: rand.Text(),
services: map[ServiceKey]serviceInfo{
ServiceKey(rand.Text()): {
serviceID: types.ServiceID(rand.Text()),
},
},
createdAt: time.Now(),
@@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
}
// Launch workers that continuously call AddPeer with new random accountIDs.
ctx, cancel := context.WithCancel(b.Context())
var wg sync.WaitGroup
for range addPeerWorkers {
wg.Go(func() {
for {
if err := nb.AddPeer(b.Context(),
wg.Add(1)
go func() {
defer wg.Done()
for ctx.Err() == nil {
if err := nb.AddPeer(ctx,
types.AccountID(rand.Text()),
domain.Domain(rand.Text()),
ServiceKey(rand.Text()),
rand.Text(),
rand.Text()); err != nil {
b.Log(err)
types.ServiceID(rand.Text())); err != nil {
return
}
}
})
}()
}
// Benchmark calling HasClient during AddPeer contention.
@@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
}
})
b.StopTimer()
cancel()
wg.Wait()
}

View File

@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -27,16 +26,15 @@ type mockStatusNotifier struct {
}
type statusCall struct {
accountID string
serviceID string
domain string
accountID types.AccountID
serviceID types.ServiceID
connected bool
}
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
return nil
}
@@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
// Initially no client exists.
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
// Add first domain - this should create a new client.
// Note: This will fail to actually connect since we use an invalid URL,
// but the client entry should still be created.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add first service - this should create a new client.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
}
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add first domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add first service.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.Equal(t, 1, nb.DomainCount(accountID))
assert.Equal(t, 1, nb.ServiceCount(accountID))
// Add second domain for the same account - should reuse existing client.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
// Add second service for the same account - should reuse existing client.
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
require.NoError(t, err)
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service")
// Add third domain.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
// Add third service.
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service")
// Still only one client.
assert.True(t, nb.HasClient(accountID))
@@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
account1 := types.AccountID("account-1")
account2 := types.AccountID("account-2")
// Add domain for account 1.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add service for account 1.
err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
// Add domain for account 2.
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
// Add service for account 2.
err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2"))
require.NoError(t, err)
// Both accounts should have their own clients.
assert.True(t, nb.HasClient(account1), "account1 should have client")
assert.True(t, nb.HasClient(account2), "account2 should have client")
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1")
assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1")
}
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add multiple domains.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add multiple services.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
require.NoError(t, err)
assert.Equal(t, 3, nb.DomainCount(accountID))
assert.Equal(t, 3, nb.ServiceCount(accountID))
// Remove one domain - client should remain.
// Remove one service - client should remain.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
assert.True(t, nb.HasClient(accountID), "client should remain after removing one service")
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2")
// Remove another domain - client should still remain.
// Remove another service - client should still remain.
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
assert.True(t, nb.HasClient(accountID), "client should remain after removing second service")
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
}
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add single domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add single service.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
// Remove the only domain - client should be removed.
// Note: Stop() may fail since the client never actually connected,
// but the entry should still be removed from the map.
// Remove the only service - client should be removed.
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
// After removing all domains, client should be gone.
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
// After removing all services, client should be gone.
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service")
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
}
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
@@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
assert.NoError(t, err, "removing from non-existent account should not error")
}
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) {
nb := mockNetBird()
accountID := types.AccountID("account-1")
// Add one domain.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
// Add one service.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
// Remove non-existent domain - should not affect existing domain.
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
// Remove non-existent service - should not affect existing service.
err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test")
require.NoError(t, err)
// Original domain should still be registered.
// Original service should still be registered.
assert.True(t, nb.HasClient(accountID))
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain")
}
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
@@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
account2 := types.AccountID("account-2")
account3 := types.AccountID("account-3")
// Add domains for multiple accounts.
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
// Add services for multiple accounts.
err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3"))
require.NoError(t, err)
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
// Stop all clients.
// Note: StopAll may return errors since clients never actually connected,
// but the clients should still be removed from the map.
_ = nb.StopAll(context.Background())
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
@@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) {
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
// Add clients for different accounts.
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1"))
require.NoError(t, err)
assert.Equal(t, 1, nb.ClientCount())
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2"))
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount())
// Adding domain to existing account should not increase count.
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
// Adding service to existing account should not increase count.
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b"))
require.NoError(t, err)
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count")
}
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
@@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
// Add first service — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
require.NoError(t, err)
// Manually mark client as started to simulate background startup completing.
@@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second domain — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
// Add second service — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
require.NoError(t, err)
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, string(accountID), calls[0].accountID)
assert.Equal(t, "svc-2", calls[0].serviceID)
assert.Equal(t, "domain2.test", calls[0].domain)
assert.Equal(t, accountID, calls[0].accountID)
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
assert.True(t, calls[0].connected)
}
@@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
require.NoError(t, err)
// Remove one domain — client stays, but disconnection notification fires.
// Remove one service — client stays, but disconnection notification fires.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, "domain1.test", calls[0].domain)
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
assert.False(t, calls[0].connected)
}