package roundtrip import ( "context" "crypto/tls" "errors" "fmt" "net" "net/http" "sync" "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/embed" nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) const deviceNamePrefix = "ingress-proxy-" // backendKey identifies a backend by its host:port from the target URL. type backendKey string // ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service) // that holds a reference to an embedded NetBird client. Callers should use the // DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions. type ServiceKey string // DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service. func DomainServiceKey(domain string) ServiceKey { return ServiceKey("domain:" + domain) } // L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP). func L4ServiceKey(id types.ServiceID) ServiceKey { return ServiceKey("l4:" + id) } var ( // ErrNoAccountID is returned when a request context is missing the account ID. ErrNoAccountID = errors.New("no account ID in request context") // ErrNoPeerConnection is returned when no embedded client exists for the account. ErrNoPeerConnection = errors.New("no peer connection found") // ErrClientStartFailed is returned when the embedded client fails to start. ErrClientStartFailed = errors.New("client start failed") // ErrTooManyInflight is returned when the per-backend in-flight limit is reached. ErrTooManyInflight = errors.New("too many in-flight requests") ) // serviceInfo holds metadata about a registered service. type serviceInfo struct { serviceID types.ServiceID } type serviceNotification struct { key ServiceKey serviceID types.ServiceID } // clientEntry holds an embedded NetBird client and tracks which services use it. type clientEntry struct { client *embed.Client transport *http.Transport // insecureTransport is a clone of transport with TLS verification disabled, // used when per-target skip_tls_verify is set. insecureTransport *http.Transport services map[ServiceKey]serviceInfo createdAt time.Time started bool // Per-backend in-flight limiting keyed by target host:port. // TODO: clean up stale entries when backend targets change. inflightMu sync.Mutex inflightMap map[backendKey]chan struct{} maxInflight int } // acquireInflight attempts to acquire an in-flight slot for the given backend. // It returns a release function that must always be called, and true on success. func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) { noop := func() {} if e.maxInflight <= 0 { return noop, true } e.inflightMu.Lock() sem, exists := e.inflightMap[backend] if !exists { sem = make(chan struct{}, e.maxInflight) e.inflightMap[backend] = sem } e.inflightMu.Unlock() select { case sem <- struct{}{}: return func() { <-sem }, true default: return noop, false } } // ClientConfig holds configuration for the embedded NetBird client. type ClientConfig struct { MgmtAddr string WGPort uint16 PreSharedKey string } type statusNotifier interface { NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error } type managementClient interface { CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest, opts ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) } // NetBird provides an http.RoundTripper implementation // backed by underlying NetBird connections. // Clients are keyed by AccountID, allowing multiple services to share the same connection. type NetBird struct { proxyID string proxyAddr string clientCfg ClientConfig logger *log.Logger mgmtClient managementClient transportCfg transportConfig clientsMux sync.RWMutex clients map[types.AccountID]*clientEntry initLogOnce sync.Once statusNotifier statusNotifier } // ClientDebugInfo contains debug information about a client. type ClientDebugInfo struct { AccountID types.AccountID ServiceCount int ServiceKeys []string HasClient bool CreatedAt time.Time } // accountIDContextKey is the context key for storing the account ID. type accountIDContextKey struct{} // skipTLSVerifyContextKey is the context key for requesting insecure TLS. type skipTLSVerifyContextKey struct{} // AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. // Multiple services can share the same client. func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { si := serviceInfo{serviceID: serviceID} n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { entry.services[key] = si started := entry.started n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, }).Debug("registered service with existing client") if started && n.statusNotifier != nil { if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil { n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, }).WithError(err).Warn("failed to notify status for existing client") } } return nil } entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) if err != nil { n.clientsMux.Unlock() return err } n.clients[accountID] = entry n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, }).Info("created new client for account") // Attempt to start the client in the background; if this fails we will // retry on the first request via RoundTrip. go n.runClientStartup(ctx, accountID, entry.client) return nil } // createClientEntry generates a WireGuard keypair, authenticates with management, // and creates an embedded NetBird client. Must be called with clientsMux held. func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { serviceID := si.serviceID n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, }).Debug("generating WireGuard keypair for new peer") privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, fmt.Errorf("generate wireguard private key: %w", err) } publicKey := privateKey.PublicKey() n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, "public_key": publicKey.String(), }).Debug("authenticating new proxy peer with management") resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{ ServiceId: string(serviceID), AccountId: string(accountID), Token: authToken, WireguardPublicKey: publicKey.String(), Cluster: n.proxyAddr, }) if err != nil { return nil, fmt.Errorf("authenticate proxy peer with management: %w", err) } if resp != nil && !resp.GetSuccess() { errMsg := "unknown error" if resp.ErrorMessage != nil { errMsg = *resp.ErrorMessage } return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg) } n.logger.WithFields(log.Fields{ "account_id": accountID, "service_id": serviceID, "public_key": publicKey.String(), }).Info("proxy peer authenticated successfully with management") n.initLogOnce.Do(func() { if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil { n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err) } }) // Create embedded NetBird client with the generated private key. // The peer has already been created via CreateProxyPeer RPC with the public key. wgPort := int(n.clientCfg.WGPort) client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) } // Create a transport using the client dialer. We do this instead of using // the client's HTTPClient to avoid issues with request validation that do // not work with reverse proxied requests. transport := &http.Transport{ DialContext: dialWithTimeout(client.DialContext), ForceAttemptHTTP2: true, MaxIdleConns: n.transportCfg.maxIdleConns, MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost, MaxConnsPerHost: n.transportCfg.maxConnsPerHost, IdleConnTimeout: n.transportCfg.idleConnTimeout, TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout, ExpectContinueTimeout: n.transportCfg.expectContinueTimeout, ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout, WriteBufferSize: n.transportCfg.writeBufferSize, ReadBufferSize: n.transportCfg.readBufferSize, DisableCompression: n.transportCfg.disableCompression, } insecureTransport := transport.Clone() insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec return &clientEntry{ client: client, services: map[ServiceKey]serviceInfo{key: si}, transport: transport, insecureTransport: insecureTransport, createdAt: time.Now(), started: false, inflightMap: make(map[backendKey]chan struct{}), maxInflight: n.transportCfg.maxInflight, }, nil } // runClientStartup starts the client and notifies registered services on success. func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) { startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() if err := client.Start(startCtx); err != nil { if errors.Is(err, context.DeadlineExceeded) { n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request") } else { n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client") } return } // Mark client as started and collect services to notify outside the lock. n.clientsMux.Lock() entry, exists := n.clients[accountID] if exists { entry.started = true } var toNotify []serviceNotification if exists { for key, info := range entry.services { toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID}) } } n.clientsMux.Unlock() if n.statusNotifier == nil { return } for _, sn := range toNotify { if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil { n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": sn.key, }).WithError(err).Warn("failed to notify tunnel connection status") } else { n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": sn.key, }).Info("notified management about tunnel connection") } } } // RemovePeer unregisters a service from an account. The client is only stopped // when no services are using it anymore. func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error { n.clientsMux.Lock() entry, exists := n.clients[accountID] if !exists { n.clientsMux.Unlock() n.logger.WithField("account_id", accountID).Debug("remove peer: account not found") return nil } si, svcExists := entry.services[key] if !svcExists { n.clientsMux.Unlock() n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, }).Debug("remove peer: service not registered") return nil } delete(entry.services, key) stopClient := len(entry.services) == 0 var client *embed.Client var transport, insecureTransport *http.Transport if stopClient { n.logger.WithField("account_id", accountID).Info("stopping client, no more services") client = entry.client transport = entry.transport insecureTransport = entry.insecureTransport delete(n.clients, accountID) } else { n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, "remaining_services": len(entry.services), }).Debug("unregistered service, client still in use") } n.clientsMux.Unlock() n.notifyDisconnect(ctx, accountID, key, si.serviceID) if stopClient { transport.CloseIdleConnections() insecureTransport.CloseIdleConnections() if err := client.Stop(ctx); err != nil { n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client") } } return nil } func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) { if n.statusNotifier == nil { return } if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil { if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound { n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification") } else { n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, }).WithError(err).Warn("failed to notify tunnel disconnection status") } } } // RoundTrip implements http.RoundTripper. It looks up the client for the account // specified in the request context and uses it to dial the backend. func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) { accountID := AccountIDFromContext(req.Context()) if accountID == "" { return nil, ErrNoAccountID } // Copy references while holding lock, then unlock early to avoid blocking // other requests during the potentially slow RoundTrip. n.clientsMux.RLock() entry, exists := n.clients[accountID] if !exists { n.clientsMux.RUnlock() return nil, fmt.Errorf("%w for account: %s", ErrNoPeerConnection, accountID) } client := entry.client transport := entry.transport if skipTLSVerifyFromContext(req.Context()) { transport = entry.insecureTransport } n.clientsMux.RUnlock() release, ok := entry.acquireInflight(backendKey(req.URL.Host)) defer release() if !ok { return nil, ErrTooManyInflight } // Attempt to start the client, if the client is already running then // it will return an error that we ignore, if this hits a timeout then // this request is unprocessable. startCtx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() if err := client.Start(startCtx); err != nil { if !errors.Is(err, embed.ErrClientAlreadyStarted) { return nil, fmt.Errorf("%w: %w", ErrClientStartFailed, err) } } start := time.Now() resp, err := transport.RoundTrip(req) duration := time.Since(start) if err != nil { n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s duration=%s err=%v", req.Method, req.Host, req.URL.String(), accountID, duration.Truncate(time.Millisecond), err) return nil, err } n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s status=%d duration=%s", req.Method, req.Host, req.URL.String(), accountID, resp.StatusCode, duration.Truncate(time.Millisecond)) return resp, nil } // StopAll stops all clients. func (n *NetBird) StopAll(ctx context.Context) error { n.clientsMux.Lock() defer n.clientsMux.Unlock() var merr *multierror.Error for accountID, entry := range n.clients { entry.transport.CloseIdleConnections() entry.insecureTransport.CloseIdleConnections() if err := entry.client.Stop(ctx); err != nil { n.logger.WithFields(log.Fields{ "account_id": accountID, }).WithError(err).Warn("failed to stop netbird client during shutdown") merr = multierror.Append(merr, err) } } maps.Clear(n.clients) return nberrors.FormatErrorOrNil(merr) } // HasClient returns true if there is a client for the given account. func (n *NetBird) HasClient(accountID types.AccountID) bool { n.clientsMux.RLock() defer n.clientsMux.RUnlock() _, exists := n.clients[accountID] return exists } // ServiceCount returns the number of services registered for the given account. // Returns 0 if the account has no client. func (n *NetBird) ServiceCount(accountID types.AccountID) int { n.clientsMux.RLock() defer n.clientsMux.RUnlock() entry, exists := n.clients[accountID] if !exists { return 0 } return len(entry.services) } // ClientCount returns the total number of active clients. func (n *NetBird) ClientCount() int { n.clientsMux.RLock() defer n.clientsMux.RUnlock() return len(n.clients) } // GetClient returns the embed.Client for the given account ID. func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) { n.clientsMux.RLock() defer n.clientsMux.RUnlock() entry, exists := n.clients[accountID] if !exists { return nil, false } return entry.client, true } // ListClientsForDebug returns information about all clients for debug purposes. func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { n.clientsMux.RLock() defer n.clientsMux.RUnlock() result := make(map[types.AccountID]ClientDebugInfo) for accountID, entry := range n.clients { keys := make([]string, 0, len(entry.services)) for k := range entry.services { keys = append(keys, string(k)) } result[accountID] = ClientDebugInfo{ AccountID: accountID, ServiceCount: len(entry.services), ServiceKeys: keys, HasClient: entry.client != nil, CreatedAt: entry.createdAt, } } return result } // ListClientsForStartup returns all embed.Client instances for health checks. func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client { n.clientsMux.RLock() defer n.clientsMux.RUnlock() result := make(map[types.AccountID]*embed.Client) for accountID, entry := range n.clients { if entry.client != nil { result[accountID] = entry.client } } return result } // NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random // OS-assigned port. A fixed port only works with single-account deployments; // multiple accounts will fail to bind the same port. func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { if logger == nil { logger = log.StandardLogger() } return &NetBird{ proxyID: proxyID, proxyAddr: proxyAddr, clientCfg: clientCfg, logger: logger, clients: make(map[types.AccountID]*clientEntry), statusNotifier: notifier, mgmtClient: mgmtClient, transportCfg: loadTransportConfig(logger), } } // dialWithTimeout wraps a DialContext function so that any dial timeout // stored in the context (via types.WithDialTimeout) is applied only to // the connection establishment phase, not the full request lifetime. func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { if d, ok := types.DialTimeoutFromContext(ctx); ok { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, d) defer cancel() } return dial(ctx, network, addr) } } // WithAccountID adds the account ID to the context. func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { return context.WithValue(ctx, accountIDContextKey{}, accountID) } // AccountIDFromContext retrieves the account ID from the context. func AccountIDFromContext(ctx context.Context) types.AccountID { v := ctx.Value(accountIDContextKey{}) if v == nil { return "" } accountID, ok := v.(types.AccountID) if !ok { return "" } return accountID } // WithSkipTLSVerify marks the context to use an insecure transport that skips // TLS certificate verification for the backend connection. func WithSkipTLSVerify(ctx context.Context) context.Context { return context.WithValue(ctx, skipTLSVerifyContextKey{}, true) } func skipTLSVerifyFromContext(ctx context.Context) bool { v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool) return v }