[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -14,11 +15,12 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/embed"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -26,7 +28,22 @@ import (
const deviceNamePrefix = "ingress-proxy-"
// backendKey identifies a backend by its host:port from the target URL.
type backendKey = string
type backendKey string
// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service)
// that holds a reference to an embedded NetBird client. Callers should use the
// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions.
type ServiceKey string
// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service.
func DomainServiceKey(domain string) ServiceKey {
return ServiceKey("domain:" + domain)
}
// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP).
func L4ServiceKey(id types.ServiceID) ServiceKey {
return ServiceKey("l4:" + id)
}
var (
// ErrNoAccountID is returned when a request context is missing the account ID.
@@ -39,24 +56,24 @@ var (
ErrTooManyInflight = errors.New("too many in-flight requests")
)
// domainInfo holds metadata about a registered domain.
type domainInfo struct {
serviceID string
// serviceInfo holds metadata about a registered service.
type serviceInfo struct {
serviceID types.ServiceID
}
type domainNotification struct {
domain domain.Domain
serviceID string
type serviceNotification struct {
key ServiceKey
serviceID types.ServiceID
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
// clientEntry holds an embedded NetBird client and tracks which services use it.
type clientEntry struct {
client *embed.Client
transport *http.Transport
// insecureTransport is a clone of transport with TLS verification disabled,
// used when per-target skip_tls_verify is set.
insecureTransport *http.Transport
domains map[domain.Domain]domainInfo
services map[ServiceKey]serviceInfo
createdAt time.Time
started bool
// Per-backend in-flight limiting keyed by target host:port.
@@ -93,12 +110,12 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
// ClientConfig holds configuration for the embedded NetBird client.
type ClientConfig struct {
MgmtAddr string
WGPort int
WGPort uint16
PreSharedKey string
}
type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error
}
type managementClient interface {
@@ -107,7 +124,7 @@ type managementClient interface {
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
type NetBird struct {
proxyID string
proxyAddr string
@@ -124,11 +141,11 @@ type NetBird struct {
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
AccountID types.AccountID
ServiceCount int
ServiceKeys []string
HasClient bool
CreatedAt time.Time
}
// accountIDContextKey is the context key for storing the account ID.
@@ -137,37 +154,37 @@ type accountIDContextKey struct{}
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
type skipTLSVerifyContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// AddPeer registers a service for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
// Multiple services can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
// Client already exists for this account, just register the domain
entry.domains[d] = domainInfo{serviceID: serviceID}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("registered domain with existing client")
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
// If client is already started, notify this domain as connected immediately
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return nil
}
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if err != nil {
n.clientsMux.Unlock()
return err
@@ -177,8 +194,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"account_id": accountID,
"service_key": key,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
@@ -190,7 +207,8 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d doma
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
@@ -209,7 +227,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
ServiceId: serviceID,
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
WireguardPublicKey: publicKey.String(),
@@ -240,13 +258,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &n.clientCfg.WGPort,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
if err != nil {
@@ -257,7 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests.
transport := &http.Transport{
DialContext: client.DialContext,
DialContext: dialWithTimeout(client.DialContext),
ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns,
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
@@ -276,7 +295,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
services: map[ServiceKey]serviceInfo{key: si},
transport: transport,
insecureTransport: insecureTransport,
createdAt: time.Now(),
@@ -286,7 +305,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}, nil
}
// runClientStartup starts the client and notifies registered domains on success.
// runClientStartup starts the client and notifies registered services on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -300,16 +319,16 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
// Mark client as started and collect domains to notify outside the lock.
// Mark client as started and collect services to notify outside the lock.
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
entry.started = true
}
var domainsToNotify []domainNotification
var toNotify []serviceNotification
if exists {
for dom, info := range entry.domains {
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
for key, info := range entry.services {
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
}
}
n.clientsMux.Unlock()
@@ -317,24 +336,24 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
if n.statusNotifier == nil {
return
}
for _, dn := range domainsToNotify {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
for _, sn := range toNotify {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
"account_id": accountID,
"service_key": sn.key,
}).WithError(err).Warn("failed to notify tunnel connection status")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": dn.domain,
"account_id": accountID,
"service_key": sn.key,
}).Info("notified management about tunnel connection")
}
}
}
// RemovePeer unregisters a domain from an account. The client is only stopped
// when no domains are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -344,74 +363,65 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
return nil
}
// Get domain info before deleting
domInfo, domainExists := entry.domains[d]
if !domainExists {
si, svcExists := entry.services[key]
if !svcExists {
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).Debug("remove peer: domain not registered")
"account_id": accountID,
"service_key": key,
}).Debug("remove peer: service not registered")
return nil
}
delete(entry.domains, d)
// If there are still domains using this client, keep it running
if len(entry.domains) > 0 {
n.clientsMux.Unlock()
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
"remaining_domains": len(entry.domains),
}).Debug("unregistered domain, client still in use")
// Notify this domain as disconnected
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
return nil
"account_id": accountID,
"service_key": key,
"remaining_services": len(entry.services),
}).Debug("unregistered service, client still in use")
}
// No more domains using this client, stop it
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).Info("stopping client, no more domains")
client := entry.client
transport := entry.transport
insecureTransport := entry.insecureTransport
delete(n.clients, accountID)
n.clientsMux.Unlock()
// Notify disconnection before stopping
if n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"domain": d,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
}).WithError(err).Warn("failed to stop netbird client")
}
return nil
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return
}
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil {
if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound {
n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification")
} else {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify tunnel disconnection status")
}
}
}
// RoundTrip implements http.RoundTripper. It looks up the client for the account
// specified in the request context and uses it to dial the backend.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -435,7 +445,7 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
}
n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host)
release, ok := entry.acquireInflight(backendKey(req.URL.Host))
defer release()
if !ok {
return nil, ErrTooManyInflight
@@ -496,16 +506,16 @@ func (n *NetBird) HasClient(accountID types.AccountID) bool {
return exists
}
// DomainCount returns the number of domains registered for the given account.
// ServiceCount returns the number of services registered for the given account.
// Returns 0 if the account has no client.
func (n *NetBird) DomainCount(accountID types.AccountID) int {
func (n *NetBird) ServiceCount(accountID types.AccountID) int {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
entry, exists := n.clients[accountID]
if !exists {
return 0
}
return len(entry.domains)
return len(entry.services)
}
// ClientCount returns the total number of active clients.
@@ -533,16 +543,16 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
result := make(map[types.AccountID]ClientDebugInfo)
for accountID, entry := range n.clients {
domains := make(domain.List, 0, len(entry.domains))
for d := range entry.domains {
domains = append(domains, d)
keys := make([]string, 0, len(entry.services))
for k := range entry.services {
keys = append(keys, string(k))
}
result[accountID] = ClientDebugInfo{
AccountID: accountID,
DomainCount: len(entry.domains),
Domains: domains,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
AccountID: accountID,
ServiceCount: len(entry.services),
ServiceKeys: keys,
HasClient: entry.client != nil,
CreatedAt: entry.createdAt,
}
}
return result
@@ -581,6 +591,20 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
}
}
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
if d, ok := types.DialTimeoutFromContext(ctx); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, d)
defer cancel()
}
return dial(ctx, network, addr)
}
}
// WithAccountID adds the account ID to the context.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
return context.WithValue(ctx, accountIDContextKey{}, accountID)