mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -14,11 +15,12 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -26,7 +28,22 @@ import (
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey = string
|
||||
type backendKey string
|
||||
|
||||
// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service)
|
||||
// that holds a reference to an embedded NetBird client. Callers should use the
|
||||
// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions.
|
||||
type ServiceKey string
|
||||
|
||||
// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service.
|
||||
func DomainServiceKey(domain string) ServiceKey {
|
||||
return ServiceKey("domain:" + domain)
|
||||
}
|
||||
|
||||
// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP).
|
||||
func L4ServiceKey(id types.ServiceID) ServiceKey {
|
||||
return ServiceKey("l4:" + id)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
@@ -39,24 +56,24 @@ var (
|
||||
ErrTooManyInflight = errors.New("too many in-flight requests")
|
||||
)
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
type domainInfo struct {
|
||||
serviceID string
|
||||
// serviceInfo holds metadata about a registered service.
|
||||
type serviceInfo struct {
|
||||
serviceID types.ServiceID
|
||||
}
|
||||
|
||||
type domainNotification struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
type serviceNotification struct {
|
||||
key ServiceKey
|
||||
serviceID types.ServiceID
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
// clientEntry holds an embedded NetBird client and tracks which services use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
// insecureTransport is a clone of transport with TLS verification disabled,
|
||||
// used when per-target skip_tls_verify is set.
|
||||
insecureTransport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
services map[ServiceKey]serviceInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
@@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
|
||||
// ClientConfig holds configuration for the embedded NetBird client.
|
||||
type ClientConfig struct {
|
||||
MgmtAddr string
|
||||
WGPort int
|
||||
WGPort uint16
|
||||
PreSharedKey string
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
|
||||
NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error
|
||||
}
|
||||
|
||||
type managementClient interface {
|
||||
@@ -107,7 +124,7 @@ type managementClient interface {
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
|
||||
type NetBird struct {
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
@@ -124,11 +141,11 @@ type NetBird struct {
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
type ClientDebugInfo struct {
|
||||
AccountID types.AccountID
|
||||
DomainCount int
|
||||
Domains domain.List
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
AccountID types.AccountID
|
||||
ServiceCount int
|
||||
ServiceKeys []string
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
@@ -137,37 +154,37 @@ type accountIDContextKey struct{}
|
||||
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
|
||||
type skipTLSVerifyContextKey struct{}
|
||||
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// AddPeer registers a service for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple domains can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
|
||||
// Multiple services can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
// Client already exists for this account, just register the domain
|
||||
entry.domains[d] = domainInfo{serviceID: serviceID}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("registered domain with existing client")
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
// If client is already started, notify this domain as connected immediately
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
|
||||
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
@@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
@@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
@@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: serviceID,
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
WireguardPublicKey: publicKey.String(),
|
||||
@@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
|
||||
// Create embedded NetBird client with the generated private key.
|
||||
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
||||
wgPort := int(n.clientCfg.WGPort)
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &n.clientCfg.WGPort,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
transport := &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
DialContext: dialWithTimeout(client.DialContext),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
@@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
|
||||
return &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
services: map[ServiceKey]serviceInfo{key: si},
|
||||
transport: transport,
|
||||
insecureTransport: insecureTransport,
|
||||
createdAt: time.Now(),
|
||||
@@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered domains on success.
|
||||
// runClientStartup starts the client and notifies registered services on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
@@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect domains to notify outside the lock.
|
||||
// Mark client as started and collect services to notify outside the lock.
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
var domainsToNotify []domainNotification
|
||||
var toNotify []serviceNotification
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
|
||||
for key, info := range entry.services {
|
||||
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
@@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
for _, dn := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
|
||||
for _, sn := range toNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
"account_id": accountID,
|
||||
"service_key": sn.key,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
"account_id": accountID,
|
||||
"service_key": sn.key,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
// when no domains are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get domain info before deleting
|
||||
domInfo, domainExists := entry.domains[d]
|
||||
if !domainExists {
|
||||
si, svcExists := entry.services[key]
|
||||
if !svcExists {
|
||||
n.clientsMux.Unlock()
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("remove peer: domain not registered")
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("remove peer: service not registered")
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(entry.domains, d)
|
||||
|
||||
// If there are still domains using this client, keep it running
|
||||
if len(entry.domains) > 0 {
|
||||
n.clientsMux.Unlock()
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"remaining_domains": len(entry.domains),
|
||||
}).Debug("unregistered domain, client still in use")
|
||||
|
||||
// Notify this domain as disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
"remaining_services": len(entry.services),
|
||||
}).Debug("unregistered service, client still in use")
|
||||
}
|
||||
|
||||
// No more domains using this client, stop it
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Info("stopping client, no more domains")
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
insecureTransport := entry.insecureTransport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify disconnection before stopping
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil {
|
||||
if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||
// specified in the request context and uses it to dial the backend.
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
release, ok := entry.acquireInflight(backendKey(req.URL.Host))
|
||||
defer release()
|
||||
if !ok {
|
||||
return nil, ErrTooManyInflight
|
||||
@@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||
return exists
|
||||
}
|
||||
|
||||
// DomainCount returns the number of domains registered for the given account.
|
||||
// ServiceCount returns the number of services registered for the given account.
|
||||
// Returns 0 if the account has no client.
|
||||
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||
func (n *NetBird) ServiceCount(accountID types.AccountID) int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
return len(entry.domains)
|
||||
return len(entry.services)
|
||||
}
|
||||
|
||||
// ClientCount returns the total number of active clients.
|
||||
@@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
|
||||
result := make(map[types.AccountID]ClientDebugInfo)
|
||||
for accountID, entry := range n.clients {
|
||||
domains := make(domain.List, 0, len(entry.domains))
|
||||
for d := range entry.domains {
|
||||
domains = append(domains, d)
|
||||
keys := make([]string, 0, len(entry.services))
|
||||
for k := range entry.services {
|
||||
keys = append(keys, string(k))
|
||||
}
|
||||
result[accountID] = ClientDebugInfo{
|
||||
AccountID: accountID,
|
||||
DomainCount: len(entry.domains),
|
||||
Domains: domains,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
AccountID: accountID,
|
||||
ServiceCount: len(entry.services),
|
||||
ServiceKeys: keys,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
}
|
||||
}
|
||||
return result
|
||||
@@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
|
||||
}
|
||||
}
|
||||
|
||||
// dialWithTimeout wraps a DialContext function so that any dial timeout
|
||||
// stored in the context (via types.WithDialTimeout) is applied only to
|
||||
// the connection establishment phase, not the full request lifetime.
|
||||
func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if d, ok := types.DialTimeoutFromContext(ctx); ok {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, d)
|
||||
defer cancel()
|
||||
}
|
||||
return dial(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAccountID adds the account ID to the context.
|
||||
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"sync"
|
||||
@@ -8,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// Simple benchmark for comparison with AddPeer contention.
|
||||
@@ -29,9 +29,9 @@ func BenchmarkHasClient(b *testing.B) {
|
||||
target = id
|
||||
}
|
||||
nb.clients[id] = &clientEntry{
|
||||
domains: map[domain.Domain]domainInfo{
|
||||
domain.Domain(rand.Text()): {
|
||||
serviceID: rand.Text(),
|
||||
services: map[ServiceKey]serviceInfo{
|
||||
ServiceKey(rand.Text()): {
|
||||
serviceID: types.ServiceID(rand.Text()),
|
||||
},
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
@@ -70,9 +70,9 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
target = id
|
||||
}
|
||||
nb.clients[id] = &clientEntry{
|
||||
domains: map[domain.Domain]domainInfo{
|
||||
domain.Domain(rand.Text()): {
|
||||
serviceID: rand.Text(),
|
||||
services: map[ServiceKey]serviceInfo{
|
||||
ServiceKey(rand.Text()): {
|
||||
serviceID: types.ServiceID(rand.Text()),
|
||||
},
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
@@ -81,19 +81,22 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
}
|
||||
|
||||
// Launch workers that continuously call AddPeer with new random accountIDs.
|
||||
ctx, cancel := context.WithCancel(b.Context())
|
||||
var wg sync.WaitGroup
|
||||
for range addPeerWorkers {
|
||||
wg.Go(func() {
|
||||
for {
|
||||
if err := nb.AddPeer(b.Context(),
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for ctx.Err() == nil {
|
||||
if err := nb.AddPeer(ctx,
|
||||
types.AccountID(rand.Text()),
|
||||
domain.Domain(rand.Text()),
|
||||
ServiceKey(rand.Text()),
|
||||
rand.Text(),
|
||||
rand.Text()); err != nil {
|
||||
b.Log(err)
|
||||
types.ServiceID(rand.Text())); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
// Benchmark calling HasClient during AddPeer contention.
|
||||
@@ -104,4 +107,6 @@ func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.StopTimer()
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -27,16 +26,15 @@ type mockStatusNotifier struct {
|
||||
}
|
||||
|
||||
type statusCall struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
domain string
|
||||
accountID types.AccountID
|
||||
serviceID types.ServiceID
|
||||
connected bool
|
||||
}
|
||||
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
|
||||
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, domain, connected})
|
||||
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,36 +60,34 @@ func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
|
||||
// Initially no client exists.
|
||||
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
|
||||
|
||||
// Add first domain - this should create a new client.
|
||||
// Note: This will fail to actually connect since we use an invalid URL,
|
||||
// but the client entry should still be created.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add first service - this should create a new client.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add first service.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID))
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID))
|
||||
|
||||
// Add second domain for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
// Add second service for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
|
||||
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2 after adding second service")
|
||||
|
||||
// Add third domain.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
// Add third service.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
|
||||
assert.Equal(t, 3, nb.ServiceCount(accountID), "service count should be 3 after adding third service")
|
||||
|
||||
// Still only one client.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
@@ -102,64 +98,62 @@ func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
|
||||
// Add domain for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add service for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add domain for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
|
||||
// Add service for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, "domain2.test", "setup-key-2", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both accounts should have their own clients.
|
||||
assert.True(t, nb.HasClient(account1), "account1 should have client")
|
||||
assert.True(t, nb.HasClient(account2), "account2 should have client")
|
||||
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.ServiceCount(account1), "account1 service count should be 1")
|
||||
assert.Equal(t, 1, nb.ServiceCount(account2), "account2 service count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenServicesRemain(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add multiple domains.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add multiple services.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "setup-key-1", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain3.test", "setup-key-1", types.ServiceID("proxy-3"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID))
|
||||
assert.Equal(t, 3, nb.ServiceCount(accountID))
|
||||
|
||||
// Remove one domain - client should remain.
|
||||
// Remove one service - client should remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one service")
|
||||
assert.Equal(t, 2, nb.ServiceCount(accountID), "service count should be 2")
|
||||
|
||||
// Remove another domain - client should still remain.
|
||||
// Remove another service - client should still remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second service")
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID), "service count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastServiceRemoved(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add single domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add single service.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
// Remove the only domain - client should be removed.
|
||||
// Note: Stop() may fail since the client never actually connected,
|
||||
// but the entry should still be removed from the map.
|
||||
// Remove the only service - client should be removed.
|
||||
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
|
||||
// After removing all domains, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
// After removing all services, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last service")
|
||||
assert.Equal(t, 0, nb.ServiceCount(accountID), "service count should be 0")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
@@ -171,21 +165,21 @@ func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
assert.NoError(t, err, "removing from non-existent account should not error")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
|
||||
func TestNetBird_RemovePeer_NonExistentServiceIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add one domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
// Add one service.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "setup-key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove non-existent domain - should not affect existing domain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
|
||||
// Remove non-existent service - should not affect existing service.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "nonexistent.test")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original domain should still be registered.
|
||||
// Original service should still be registered.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
|
||||
assert.Equal(t, 1, nb.ServiceCount(accountID), "original service should remain")
|
||||
}
|
||||
|
||||
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
|
||||
@@ -216,19 +210,17 @@ func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
|
||||
account2 := types.AccountID("account-2")
|
||||
account3 := types.AccountID("account-3")
|
||||
|
||||
// Add domains for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
// Add services for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, "domain1.test", "key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
err = nb.AddPeer(context.Background(), account2, "domain2.test", "key-2", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
|
||||
err = nb.AddPeer(context.Background(), account3, "domain3.test", "key-3", types.ServiceID("proxy-3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
|
||||
|
||||
// Stop all clients.
|
||||
// Note: StopAll may return errors since clients never actually connected,
|
||||
// but the clients should still be removed from the map.
|
||||
_ = nb.StopAll(context.Background())
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
|
||||
@@ -243,18 +235,18 @@ func TestNetBird_ClientCount(t *testing.T) {
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
|
||||
|
||||
// Add clients for different accounts.
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1.test", "key-1", types.ServiceID("proxy-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.ClientCount())
|
||||
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), "domain2.test", "key-2", types.ServiceID("proxy-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount())
|
||||
|
||||
// Adding domain to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
|
||||
// Adding service to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), "domain1b.test", "key-1", types.ServiceID("proxy-1b"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding service to existing account should not increase client count")
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
|
||||
@@ -293,8 +285,8 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
// Add first service — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually mark client as started to simulate background startup completing.
|
||||
@@ -302,15 +294,14 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second domain — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
// Add second service — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, string(accountID), calls[0].accountID)
|
||||
assert.Equal(t, "svc-2", calls[0].serviceID)
|
||||
assert.Equal(t, "domain2.test", calls[0].domain)
|
||||
assert.Equal(t, accountID, calls[0].accountID)
|
||||
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
@@ -323,18 +314,18 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
err := nb.AddPeer(context.Background(), accountID, "domain1.test", "key-1", types.ServiceID("svc-1"))
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove one domain — client stays, but disconnection notification fires.
|
||||
// Remove one service — client stays, but disconnection notification fires.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "domain1.test", calls[0].domain)
|
||||
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user