// Mirror of https://github.com/netbirdio/netbird.git
// Synced 2026-04-16 15:26:40 +00:00 (636 lines, 20 KiB, Go)
package roundtrip
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/exp/maps"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
grpcstatus "google.golang.org/grpc/status"
|
|
|
|
"github.com/netbirdio/netbird/client/embed"
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
|
"github.com/netbirdio/netbird/shared/management/proto"
|
|
"github.com/netbirdio/netbird/util"
|
|
)
|
|
|
|
// deviceNamePrefix is prepended to the proxy ID to form the embedded
// client's device name as registered with the management server.
const deviceNamePrefix = "ingress-proxy-"

// backendKey identifies a backend by its host:port from the target URL.
// It scopes the per-backend in-flight request limiting in clientEntry.
type backendKey string

// ServiceKey uniquely identifies a service (HTTP reverse proxy or L4 service)
// that holds a reference to an embedded NetBird client. Callers should use the
// DomainServiceKey and L4ServiceKey constructors to avoid namespace collisions.
type ServiceKey string
|
|
|
|
// DomainServiceKey returns a ServiceKey for an HTTP/TLS domain-based service.
|
|
func DomainServiceKey(domain string) ServiceKey {
|
|
return ServiceKey("domain:" + domain)
|
|
}
|
|
|
|
// L4ServiceKey returns a ServiceKey for an L4 service (TCP/UDP).
|
|
func L4ServiceKey(id types.ServiceID) ServiceKey {
|
|
return ServiceKey("l4:" + id)
|
|
}
|
|
|
|
// Sentinel errors returned by the NetBird transport; match with errors.Is.
var (
	// ErrNoAccountID is returned when a request context is missing the account ID.
	ErrNoAccountID = errors.New("no account ID in request context")
	// ErrNoPeerConnection is returned when no embedded client exists for the account.
	ErrNoPeerConnection = errors.New("no peer connection found")
	// ErrClientStartFailed is returned when the embedded client fails to start.
	ErrClientStartFailed = errors.New("client start failed")
	// ErrTooManyInflight is returned when the per-backend in-flight limit is reached.
	ErrTooManyInflight = errors.New("too many in-flight requests")
)
|
|
|
|
// serviceInfo holds metadata about a registered service.
type serviceInfo struct {
	// serviceID is the management-side identifier of the service.
	serviceID types.ServiceID
}

// serviceNotification pairs a service key with its ID so that connection
// status notifications can be sent after the clients lock is released.
type serviceNotification struct {
	key       ServiceKey
	serviceID types.ServiceID
}
|
|
|
|
// clientEntry holds an embedded NetBird client and tracks which services use it.
// An entry is created by the first AddPeer for an account and torn down by
// RemovePeer once the last service is unregistered.
type clientEntry struct {
	// client is the embedded NetBird client shared by all services of the account.
	client *embed.Client
	// transport dials backends through the embedded client's tunnel.
	transport *http.Transport
	// insecureTransport is a clone of transport with TLS verification disabled,
	// used when per-target skip_tls_verify is set.
	insecureTransport *http.Transport
	// services maps each registered service key to its metadata.
	services map[ServiceKey]serviceInfo
	// createdAt records when this entry was created (surfaced in debug info).
	createdAt time.Time
	// started is set once client.Start has succeeded; guarded by NetBird.clientsMux.
	started bool

	// Per-backend in-flight limiting keyed by target host:port.
	// TODO: clean up stale entries when backend targets change.
	inflightMu  sync.Mutex
	inflightMap map[backendKey]chan struct{}
	// maxInflight caps concurrent requests per backend; <= 0 disables limiting.
	maxInflight int
}
|
|
|
|
// acquireInflight attempts to acquire an in-flight slot for the given backend.
|
|
// It returns a release function that must always be called, and true on success.
|
|
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
|
|
noop := func() {}
|
|
if e.maxInflight <= 0 {
|
|
return noop, true
|
|
}
|
|
|
|
e.inflightMu.Lock()
|
|
sem, exists := e.inflightMap[backend]
|
|
if !exists {
|
|
sem = make(chan struct{}, e.maxInflight)
|
|
e.inflightMap[backend] = sem
|
|
}
|
|
e.inflightMu.Unlock()
|
|
|
|
select {
|
|
case sem <- struct{}{}:
|
|
return func() { <-sem }, true
|
|
default:
|
|
return noop, false
|
|
}
|
|
}
|
|
|
|
// ClientConfig holds configuration for the embedded NetBird client.
type ClientConfig struct {
	// MgmtAddr is the management server URL the embedded client connects to.
	MgmtAddr string
	// WGPort is the local WireGuard listen port; 0 selects a random OS port.
	WGPort uint16
	// PreSharedKey is the optional WireGuard pre-shared key.
	PreSharedKey string
}

// statusNotifier reports a service's tunnel connection state to management.
type statusNotifier interface {
	NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error
}

// managementClient is the subset of the management gRPC API used to
// authenticate newly created proxy peers.
type managementClient interface {
	CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest, opts ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error)
}
|
|
|
|
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
type NetBird struct {
	// proxyID identifies this proxy instance; used to build the device name.
	proxyID string
	// proxyAddr is reported to management as the peer's cluster address.
	proxyAddr string
	// clientCfg holds settings applied to every embedded client.
	clientCfg ClientConfig
	logger    *log.Logger
	// mgmtClient authenticates new proxy peers with the management server.
	mgmtClient managementClient
	// transportCfg holds the shared http.Transport tuning parameters.
	transportCfg transportConfig

	// clientsMux guards clients and each entry's started flag.
	clientsMux sync.RWMutex
	clients    map[types.AccountID]*clientEntry
	// initLogOnce ensures embedded-client logging is initialized only once.
	initLogOnce    sync.Once
	statusNotifier statusNotifier
}

// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
	AccountID    types.AccountID
	ServiceCount int
	// ServiceKeys lists the registered service keys as plain strings.
	ServiceKeys []string
	// HasClient indicates the entry carries a non-nil embedded client.
	HasClient bool
	CreatedAt time.Time
}
|
|
|
|
// accountIDContextKey is the context key for storing the account ID.
// An unexported struct type guarantees no collision with other packages.
type accountIDContextKey struct{}

// skipTLSVerifyContextKey is the context key for requesting insecure TLS
// toward the backend for a single request.
type skipTLSVerifyContextKey struct{}
|
|
|
|
// AddPeer registers a service for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple services can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
	si := serviceInfo{serviceID: serviceID}

	n.clientsMux.Lock()

	entry, exists := n.clients[accountID]
	if exists {
		// Reuse the existing client: record the new service and snapshot the
		// started flag before dropping the lock.
		entry.services[key] = si
		started := entry.started
		n.clientsMux.Unlock()

		n.logger.WithFields(log.Fields{
			"account_id":  accountID,
			"service_key": key,
		}).Debug("registered service with existing client")

		// If the shared client is already up, report this service connected
		// immediately; otherwise runClientStartup will notify on success.
		if started && n.statusNotifier != nil {
			if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
				n.logger.WithFields(log.Fields{
					"account_id":  accountID,
					"service_key": key,
				}).WithError(err).Warn("failed to notify status for existing client")
			}
		}
		return nil
	}

	// No client yet: create one while still holding the lock so concurrent
	// AddPeer calls for the same account cannot race the creation.
	entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
	if err != nil {
		n.clientsMux.Unlock()
		return err
	}

	n.clients[accountID] = entry
	n.clientsMux.Unlock()

	n.logger.WithFields(log.Fields{
		"account_id":  accountID,
		"service_key": key,
	}).Info("created new client for account")

	// Attempt to start the client in the background; if this fails we will
	// retry on the first request via RoundTrip.
	// NOTE(review): ctx is handed to a background goroutine that uses it for
	// NotifyStatus — if callers pass a request-scoped context, notifications
	// may fail after the request ends; confirm callers pass a long-lived ctx.
	go n.runClientStartup(ctx, accountID, entry.client)

	return nil
}
|
|
|
|
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
// The returned entry is not started; callers start it asynchronously or on
// first request.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
	serviceID := si.serviceID
	n.logger.WithFields(log.Fields{
		"account_id": accountID,
		"service_id": serviceID,
	}).Debug("generating WireGuard keypair for new peer")

	privateKey, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("generate wireguard private key: %w", err)
	}
	publicKey := privateKey.PublicKey()

	n.logger.WithFields(log.Fields{
		"account_id": accountID,
		"service_id": serviceID,
		"public_key": publicKey.String(),
	}).Debug("authenticating new proxy peer with management")

	// Register the public key with management; the private key never leaves
	// this process.
	resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
		ServiceId:          string(serviceID),
		AccountId:          string(accountID),
		Token:              authToken,
		WireguardPublicKey: publicKey.String(),
		Cluster:            n.proxyAddr,
	})
	if err != nil {
		return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
	}
	// The RPC can also report failure in-band via the Success flag.
	if resp != nil && !resp.GetSuccess() {
		errMsg := "unknown error"
		if resp.ErrorMessage != nil {
			errMsg = *resp.ErrorMessage
		}
		return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
	}

	n.logger.WithFields(log.Fields{
		"account_id": accountID,
		"service_id": serviceID,
		"public_key": publicKey.String(),
	}).Info("proxy peer authenticated successfully with management")

	// Embedded-client logging is process-global; initialize it exactly once.
	n.initLogOnce.Do(func() {
		if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
			n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
		}
	})

	// Create embedded NetBird client with the generated private key.
	// The peer has already been created via CreateProxyPeer RPC with the public key.
	wgPort := int(n.clientCfg.WGPort)
	client, err := embed.New(embed.Options{
		DeviceName:    deviceNamePrefix + n.proxyID,
		ManagementURL: n.clientCfg.MgmtAddr,
		PrivateKey:    privateKey.String(),
		LogLevel:      log.WarnLevel.String(),
		BlockInbound:  true,
		WireguardPort: &wgPort,
		PreSharedKey:  n.clientCfg.PreSharedKey,
	})
	if err != nil {
		return nil, fmt.Errorf("create netbird client: %w", err)
	}

	// Create a transport using the client dialer. We do this instead of using
	// the client's HTTPClient to avoid issues with request validation that do
	// not work with reverse proxied requests.
	transport := &http.Transport{
		DialContext:           dialWithTimeout(client.DialContext),
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          n.transportCfg.maxIdleConns,
		MaxIdleConnsPerHost:   n.transportCfg.maxIdleConnsPerHost,
		MaxConnsPerHost:       n.transportCfg.maxConnsPerHost,
		IdleConnTimeout:       n.transportCfg.idleConnTimeout,
		TLSHandshakeTimeout:   n.transportCfg.tlsHandshakeTimeout,
		ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
		ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
		WriteBufferSize:       n.transportCfg.writeBufferSize,
		ReadBufferSize:        n.transportCfg.readBufferSize,
		DisableCompression:    n.transportCfg.disableCompression,
	}

	// Insecure variant for targets that opt out of TLS verification.
	insecureTransport := transport.Clone()
	insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec

	return &clientEntry{
		client:            client,
		services:          map[ServiceKey]serviceInfo{key: si},
		transport:         transport,
		insecureTransport: insecureTransport,
		createdAt:         time.Now(),
		started:           false,
		inflightMap:       make(map[backendKey]chan struct{}),
		maxInflight:       n.transportCfg.maxInflight,
	}, nil
}
|
|
|
|
// runClientStartup starts the client and notifies registered services on success.
|
|
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
|
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
if err := client.Start(startCtx); err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
|
|
} else {
|
|
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
|
|
}
|
|
return
|
|
}
|
|
|
|
// Mark client as started and collect services to notify outside the lock.
|
|
n.clientsMux.Lock()
|
|
entry, exists := n.clients[accountID]
|
|
if exists {
|
|
entry.started = true
|
|
}
|
|
var toNotify []serviceNotification
|
|
if exists {
|
|
for key, info := range entry.services {
|
|
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
|
|
}
|
|
}
|
|
n.clientsMux.Unlock()
|
|
|
|
if n.statusNotifier == nil {
|
|
return
|
|
}
|
|
for _, sn := range toNotify {
|
|
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
|
|
n.logger.WithFields(log.Fields{
|
|
"account_id": accountID,
|
|
"service_key": sn.key,
|
|
}).WithError(err).Warn("failed to notify tunnel connection status")
|
|
} else {
|
|
n.logger.WithFields(log.Fields{
|
|
"account_id": accountID,
|
|
"service_key": sn.key,
|
|
}).Info("notified management about tunnel connection")
|
|
}
|
|
}
|
|
}
|
|
|
|
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore. It is idempotent: unknown accounts
// or services are logged at debug level and return nil.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
	n.clientsMux.Lock()

	entry, exists := n.clients[accountID]
	if !exists {
		n.clientsMux.Unlock()
		n.logger.WithField("account_id", accountID).Debug("remove peer: account not found")
		return nil
	}

	si, svcExists := entry.services[key]
	if !svcExists {
		n.clientsMux.Unlock()
		n.logger.WithFields(log.Fields{
			"account_id":  accountID,
			"service_key": key,
		}).Debug("remove peer: service not registered")
		return nil
	}

	delete(entry.services, key)

	// Tear the client down only when this was its last service. References
	// are copied out here so the potentially slow stop and connection
	// cleanup run after the lock is released.
	stopClient := len(entry.services) == 0
	var client *embed.Client
	var transport, insecureTransport *http.Transport
	if stopClient {
		n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
		client = entry.client
		transport = entry.transport
		insecureTransport = entry.insecureTransport
		delete(n.clients, accountID)
	} else {
		n.logger.WithFields(log.Fields{
			"account_id":         accountID,
			"service_key":        key,
			"remaining_services": len(entry.services),
		}).Debug("unregistered service, client still in use")
	}
	n.clientsMux.Unlock()

	// Always report the disconnect for the removed service, even when the
	// shared client keeps running for other services.
	n.notifyDisconnect(ctx, accountID, key, si.serviceID)

	if stopClient {
		transport.CloseIdleConnections()
		insecureTransport.CloseIdleConnections()
		if err := client.Stop(ctx); err != nil {
			n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
		}
	}

	return nil
}
|
|
|
|
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
|
if n.statusNotifier == nil {
|
|
return
|
|
}
|
|
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, false); err != nil {
|
|
if s, ok := grpcstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
|
n.logger.WithField("service_key", key).Debug("service already removed, skipping disconnect notification")
|
|
} else {
|
|
n.logger.WithFields(log.Fields{
|
|
"account_id": accountID,
|
|
"service_key": key,
|
|
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
|
}
|
|
}
|
|
}
|
|
|
|
// RoundTrip implements http.RoundTripper. It looks up the client for the account
// specified in the request context and uses it to dial the backend.
// It returns ErrNoAccountID when the context carries no account ID,
// ErrNoPeerConnection when the account has no client, and ErrTooManyInflight
// when the per-backend in-flight limit is exhausted.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
	accountID := AccountIDFromContext(req.Context())
	if accountID == "" {
		return nil, ErrNoAccountID
	}

	// Copy references while holding lock, then unlock early to avoid blocking
	// other requests during the potentially slow RoundTrip.
	n.clientsMux.RLock()
	entry, exists := n.clients[accountID]
	if !exists {
		n.clientsMux.RUnlock()
		return nil, fmt.Errorf("%w for account: %s", ErrNoPeerConnection, accountID)
	}
	client := entry.client
	transport := entry.transport
	if skipTLSVerifyFromContext(req.Context()) {
		transport = entry.insecureTransport
	}
	n.clientsMux.RUnlock()

	// release is a no-op when acquisition fails, so deferring it
	// unconditionally before checking ok is safe.
	// NOTE(review): the slot is released when RoundTrip returns, i.e. before
	// the caller finishes streaming the response body — confirm that is the
	// intended accounting for the in-flight limit.
	release, ok := entry.acquireInflight(backendKey(req.URL.Host))
	defer release()
	if !ok {
		return nil, ErrTooManyInflight
	}

	// Attempt to start the client, if the client is already running then
	// it will return an error that we ignore, if this hits a timeout then
	// this request is unprocessable.
	startCtx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
	defer cancel()
	if err := client.Start(startCtx); err != nil {
		if !errors.Is(err, embed.ErrClientAlreadyStarted) {
			return nil, fmt.Errorf("%w: %w", ErrClientStartFailed, err)
		}
	}

	start := time.Now()
	resp, err := transport.RoundTrip(req)
	duration := time.Since(start)

	if err != nil {
		n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s duration=%s err=%v",
			req.Method, req.Host, req.URL.String(), accountID, duration.Truncate(time.Millisecond), err)
		return nil, err
	}

	n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s status=%d duration=%s",
		req.Method, req.Host, req.URL.String(), accountID, resp.StatusCode, duration.Truncate(time.Millisecond))
	return resp, nil
}
|
|
|
|
// StopAll stops all clients.
|
|
func (n *NetBird) StopAll(ctx context.Context) error {
|
|
n.clientsMux.Lock()
|
|
defer n.clientsMux.Unlock()
|
|
|
|
var merr *multierror.Error
|
|
for accountID, entry := range n.clients {
|
|
entry.transport.CloseIdleConnections()
|
|
entry.insecureTransport.CloseIdleConnections()
|
|
if err := entry.client.Stop(ctx); err != nil {
|
|
n.logger.WithFields(log.Fields{
|
|
"account_id": accountID,
|
|
}).WithError(err).Warn("failed to stop netbird client during shutdown")
|
|
merr = multierror.Append(merr, err)
|
|
}
|
|
}
|
|
maps.Clear(n.clients)
|
|
|
|
return nberrors.FormatErrorOrNil(merr)
|
|
}
|
|
|
|
// HasClient returns true if there is a client for the given account.
|
|
func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
|
n.clientsMux.RLock()
|
|
defer n.clientsMux.RUnlock()
|
|
_, exists := n.clients[accountID]
|
|
return exists
|
|
}
|
|
|
|
// ServiceCount returns the number of services registered for the given account.
|
|
// Returns 0 if the account has no client.
|
|
func (n *NetBird) ServiceCount(accountID types.AccountID) int {
|
|
n.clientsMux.RLock()
|
|
defer n.clientsMux.RUnlock()
|
|
entry, exists := n.clients[accountID]
|
|
if !exists {
|
|
return 0
|
|
}
|
|
return len(entry.services)
|
|
}
|
|
|
|
// ClientCount returns the total number of active clients.
|
|
func (n *NetBird) ClientCount() int {
|
|
n.clientsMux.RLock()
|
|
defer n.clientsMux.RUnlock()
|
|
return len(n.clients)
|
|
}
|
|
|
|
// GetClient returns the embed.Client for the given account ID.
|
|
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
|
n.clientsMux.RLock()
|
|
defer n.clientsMux.RUnlock()
|
|
entry, exists := n.clients[accountID]
|
|
if !exists {
|
|
return nil, false
|
|
}
|
|
return entry.client, true
|
|
}
|
|
|
|
// ListClientsForDebug returns information about all clients for debug purposes.
|
|
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
|
n.clientsMux.RLock()
|
|
defer n.clientsMux.RUnlock()
|
|
|
|
result := make(map[types.AccountID]ClientDebugInfo)
|
|
for accountID, entry := range n.clients {
|
|
keys := make([]string, 0, len(entry.services))
|
|
for k := range entry.services {
|
|
keys = append(keys, string(k))
|
|
}
|
|
result[accountID] = ClientDebugInfo{
|
|
AccountID: accountID,
|
|
ServiceCount: len(entry.services),
|
|
ServiceKeys: keys,
|
|
HasClient: entry.client != nil,
|
|
CreatedAt: entry.createdAt,
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// ListClientsForStartup returns all embed.Client instances for health checks.
|
|
func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
|
|
n.clientsMux.RLock()
|
|
defer n.clientsMux.RUnlock()
|
|
|
|
result := make(map[types.AccountID]*embed.Client)
|
|
for accountID, entry := range n.clients {
|
|
if entry.client != nil {
|
|
result[accountID] = entry.client
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random
|
|
// OS-assigned port. A fixed port only works with single-account deployments;
|
|
// multiple accounts will fail to bind the same port.
|
|
func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
|
|
if logger == nil {
|
|
logger = log.StandardLogger()
|
|
}
|
|
return &NetBird{
|
|
proxyID: proxyID,
|
|
proxyAddr: proxyAddr,
|
|
clientCfg: clientCfg,
|
|
logger: logger,
|
|
clients: make(map[types.AccountID]*clientEntry),
|
|
statusNotifier: notifier,
|
|
mgmtClient: mgmtClient,
|
|
transportCfg: loadTransportConfig(logger),
|
|
}
|
|
}
|
|
|
|
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
func dialWithTimeout(dial func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		if d, ok := types.DialTimeoutFromContext(ctx); ok {
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, d)
			// cancel fires as soon as this function returns. This assumes
			// the wrapped dialer does not tie the returned conn's lifetime
			// to ctx (as net.Dialer does) — NOTE(review): confirm
			// embed.Client.DialContext behaves that way.
			defer cancel()
		}
		return dial(ctx, network, addr)
	}
}
|
|
|
|
// WithAccountID adds the account ID to the context so RoundTrip can select
// the matching embedded client.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
	return context.WithValue(ctx, accountIDContextKey{}, accountID)
}
|
|
|
|
// AccountIDFromContext retrieves the account ID from the context.
|
|
func AccountIDFromContext(ctx context.Context) types.AccountID {
|
|
v := ctx.Value(accountIDContextKey{})
|
|
if v == nil {
|
|
return ""
|
|
}
|
|
accountID, ok := v.(types.AccountID)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return accountID
|
|
}
|
|
|
|
// WithSkipTLSVerify marks the context to use an insecure transport that skips
// TLS certificate verification for the backend connection.
func WithSkipTLSVerify(ctx context.Context) context.Context {
	return context.WithValue(ctx, skipTLSVerifyContextKey{}, true)
}
|
|
|
|
func skipTLSVerifyFromContext(ctx context.Context) bool {
|
|
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
|
|
return v
|
|
}
|