mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 03:06:38 +00:00
Use a 1:1 mapping of netbird client to netbird account
- Add debug endpoint for monitoring netbird clients - Add types package with AccountID type - Refactor netbird roundtrip to key clients by AccountID - Multiple domains can share the same client per account - Add status notifier for tunnel connection updates - Add OIDC flags to CLI - Add tests for netbird client management
This commit is contained in:
@@ -4,18 +4,39 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const deviceNamePrefix = "ingress-"
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
// ErrNoAccountID is returned when a request context is missing the account ID.
|
||||
var ErrNoAccountID = errors.New("no account ID in request context")
|
||||
|
||||
// domainInfo holds metadata about a registered domain.
|
||||
type domainInfo struct {
|
||||
reverseProxyID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
NotifyStatus(ctx context.Context, accountID, reverseProxyID, domain string, connected bool) error
|
||||
@@ -23,147 +44,389 @@ type statusNotifier interface {
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
logger *log.Logger
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[string]*embed.Client
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
}
|
||||
|
||||
func NewNetBird(mgmtAddr string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
// NewNetBird creates a new NetBird transport.
|
||||
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &NetBird{
|
||||
mgmtAddr: mgmtAddr,
|
||||
proxyID: proxyID,
|
||||
logger: logger,
|
||||
clients: make(map[string]*embed.Client),
|
||||
clients: make(map[types.AccountID]*clientEntry),
|
||||
statusNotifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) AddPeer(ctx context.Context, domain, key, accountID, reverseProxyID string) error {
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// one is created using the provided setup key. Multiple domains can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
// Client already exists for this account, just register the domain
|
||||
entry.domains[d] = domainInfo{reverseProxyID: reverseProxyID}
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("registered domain with existing client")
|
||||
|
||||
// If client is already started, notify this domain as connected immediately
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), reverseProxyID, string(d), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
n.initLogOnce.Do(func() {
|
||||
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
|
||||
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
wgPort := 0
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + domain,
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.mgmtAddr,
|
||||
SetupKey: key,
|
||||
LogOutput: io.Discard,
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
entry = &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {reverseProxyID: reverseProxyID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background, if this fails
|
||||
// then it is not ideal, but it isn't the end of the world because
|
||||
// we will try to start the client again before we use it.
|
||||
go func() {
|
||||
startCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
err = client.Start(startCtx)
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
n.logger.Debug("netbird client timed out")
|
||||
// This is not ideal, but we will try again later.
|
||||
return
|
||||
case err != nil:
|
||||
n.logger.WithField("domain", domain).WithError(err).Error("Unable to start netbird client, will try again later.")
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Debug("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Notify management that tunnel is now active
|
||||
// Mark client as started and notify all registered domains
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
// Copy domain info while holding lock
|
||||
var domainsToNotify []struct {
|
||||
domain domain.Domain
|
||||
reverseProxyID string
|
||||
}
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, struct {
|
||||
domain domain.Domain
|
||||
reverseProxyID string
|
||||
}{domain: dom, reverseProxyID: info.reverseProxyID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify all domains that they're connected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, true); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel connection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel connection")
|
||||
for _, domInfo := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(domInfo.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": domInfo.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
n.clients[domain] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, domain, accountID, reverseProxyID string) error {
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[domain]
|
||||
n.clientsMux.RUnlock()
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
// when no domains are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
// Mission failed successfully!
|
||||
n.clientsMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("stop netbird client: %w", err)
|
||||
|
||||
// Get domain info before deleting
|
||||
domInfo, domainExists := entry.domains[d]
|
||||
if !domainExists {
|
||||
n.clientsMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify management that tunnel is disconnected
|
||||
delete(entry.domains, d)
|
||||
|
||||
// If there are still domains using this client, keep it running
|
||||
if len(entry.domains) > 0 {
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"remaining_domains": len(entry.domains),
|
||||
}).Debug("unregistered domain, client still in use")
|
||||
|
||||
// Notify this domain as disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// No more domains using this client, stop it
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Info("stopping client, no more domains")
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify disconnection before stopping
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, accountID, reverseProxyID, domain, false); err != nil {
|
||||
n.logger.WithField("domain", domain).WithError(err).Warn("Failed to notify management about tunnel disconnection")
|
||||
} else {
|
||||
n.logger.WithField("domain", domain).Info("Successfully notified management about tunnel disconnection")
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.reverseProxyID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
delete(n.clients, domain)
|
||||
transport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||
// specified in the request context and uses it to dial the backend.
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
host = req.Host
|
||||
accountID := AccountIDFromContext(req.Context())
|
||||
if accountID == "" {
|
||||
return nil, ErrNoAccountID
|
||||
}
|
||||
|
||||
// Copy references while holding lock, then unlock early to avoid blocking
|
||||
// other requests during the potentially slow RoundTrip.
|
||||
n.clientsMux.RLock()
|
||||
client, exists := n.clients[host]
|
||||
// Immediately unlock after retrieval here rather than defer to avoid
|
||||
// the call to client.Do blocking other clients being used whilst one
|
||||
// is in use.
|
||||
n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no peer connection found for host: %s", host)
|
||||
n.clientsMux.RUnlock()
|
||||
return nil, fmt.Errorf("no peer connection found for account: %s", accountID)
|
||||
}
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
// Attempt to start the client, if the client is already running then
|
||||
// it will return an error that we ignore, if this hits a timeout then
|
||||
// this request is unprocessable.
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = client.Start(startCtx)
|
||||
switch {
|
||||
case errors.Is(err, embed.ErrClientAlreadyStarted):
|
||||
break
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
|
||||
return nil, fmt.Errorf("start netbird client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"account_id": accountID,
|
||||
"host": req.Host,
|
||||
"url": req.URL.String(),
|
||||
"requestURI": req.RequestURI,
|
||||
"method": req.Method,
|
||||
}).Debug("running roundtrip for peer connection")
|
||||
|
||||
// Create a new transport using the client dialer and perform the roundtrip.
|
||||
// We do this instead of using the client HTTPClient to avoid issues around
|
||||
// client request validation that do not work with the reverse proxied
|
||||
// requests.
|
||||
// Other values are simply copied from the http.DefaultTransport which the
|
||||
// standard reverse proxy implementation would have used.
|
||||
// TODO: tune this transport for our needs.
|
||||
return (&http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}).RoundTrip(req)
|
||||
return transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// StopAll stops all clients.
|
||||
func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
entry.transport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client during shutdown")
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
maps.Clear(n.clients)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// HasClient returns true if there is a client for the given account.
|
||||
func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
_, exists := n.clients[accountID]
|
||||
return exists
|
||||
}
|
||||
|
||||
// DomainCount returns the number of domains registered for the given account.
|
||||
// Returns 0 if the account has no client.
|
||||
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
return len(entry.domains)
|
||||
}
|
||||
|
||||
// ClientCount returns the total number of active clients.
|
||||
func (n *NetBird) ClientCount() int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
return len(n.clients)
|
||||
}
|
||||
|
||||
// GetClient returns the embed.Client for the given account ID.
|
||||
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
return entry.client, true
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
type ClientDebugInfo struct {
|
||||
AccountID types.AccountID
|
||||
DomainCount int
|
||||
Domains domain.List
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
|
||||
result := make(map[types.AccountID]ClientDebugInfo)
|
||||
for accountID, entry := range n.clients {
|
||||
domains := make(domain.List, 0, len(entry.domains))
|
||||
for d := range entry.domains {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
result[accountID] = ClientDebugInfo{
|
||||
AccountID: accountID,
|
||||
DomainCount: len(entry.domains),
|
||||
Domains: domains,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// accountIDContextKey is the context key for storing the account ID.
|
||||
type accountIDContextKey struct{}
|
||||
|
||||
// WithAccountID adds the account ID to the context.
|
||||
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||
}
|
||||
|
||||
// AccountIDFromContext retrieves the account ID from the context.
|
||||
func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIDContextKey{})
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
accountID, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountID
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user