mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 08:09:55 +00:00
feat(private-service): expose NetBird-only services over tunnel peers
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes. Wire contract - ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed. - ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy. - ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups. - ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip. - ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards. - PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard. Data model - Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns. - DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure. - Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0. Proxy - /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go. - Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack. - proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction. - Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC. - proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level). Identity forwarding - ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request. - CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing). - AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships. OpenAPI/dashboard surface - ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
This commit is contained in:
738
proxy/server.go
738
proxy/server.go
@@ -32,11 +32,10 @@ import (
|
||||
"go.opentelemetry.io/otel/sdk/metric"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
@@ -114,9 +113,28 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
inbound *inboundManager
|
||||
|
||||
// Lifecycle state — populated by Start, consumed by Stop. The fields
|
||||
// stay zero on a fresh Server until Start runs so direct struct
|
||||
// construction (`&Server{...}`) used by tests still works.
|
||||
runCancel context.CancelFunc
|
||||
mgmtConn *grpc.ClientConn
|
||||
runErr error
|
||||
runErrCh chan struct{}
|
||||
startMu sync.Mutex
|
||||
started bool
|
||||
stopOnce sync.Once
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
|
||||
// ListenAddr is the address the main TCP listener binds. Populated by
|
||||
// New from Config or by ListenAndServe from its addr argument.
|
||||
ListenAddr string
|
||||
ID string
|
||||
Logger *log.Logger
|
||||
Version string
|
||||
@@ -177,6 +195,14 @@ type Server struct {
|
||||
// in front of this proxy's cluster domain. When true, accounts cannot
|
||||
// create services on the bare cluster domain.
|
||||
RequireSubdomain bool
|
||||
// Private flags this proxy as embedded in a netbird client and serving
|
||||
// exclusively over the WireGuard tunnel (i.e. `netbird proxy`). Reported
|
||||
// upstream as a capability so dashboards can distinguish per-peer
|
||||
// clusters from centralised ones, and turns on per-account inbound
|
||||
// listeners on each embedded client's netstack: every account that
|
||||
// registers a service exposes :80 + :443 inside its own WG tunnel,
|
||||
// scoped to that account's services only.
|
||||
Private bool
|
||||
// MaxDialTimeout caps the per-service backend dial timeout.
|
||||
// When the API sends a timeout, it is clamped to this value.
|
||||
// When the API sends no timeout, this value is used as the default.
|
||||
@@ -222,12 +248,16 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, se
|
||||
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
|
||||
}
|
||||
|
||||
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
|
||||
req := &proto.SendStatusUpdateRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Status: status,
|
||||
CertificateIssued: false,
|
||||
})
|
||||
}
|
||||
if connected {
|
||||
req.InboundListener = s.inboundListenerProto(accountID)
|
||||
}
|
||||
_, err := s.mgmtClient.SendStatusUpdate(ctx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -238,56 +268,68 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
|
||||
AccountId: string(accountID),
|
||||
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
|
||||
CertificateIssued: true,
|
||||
InboundListener: s.inboundListenerProto(accountID),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.initDefaults()
|
||||
s.routerReady = make(chan struct{})
|
||||
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
|
||||
s.portRouters = make(map[uint16]*portRouter)
|
||||
s.svcPorts = make(map[types.ServiceID][]uint16)
|
||||
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
|
||||
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create prometheus exporter: %w", err)
|
||||
// inboundListenerProto resolves the per-account inbound listener state for
|
||||
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
|
||||
// or the account has no live listener so management treats the field as
|
||||
// absent.
|
||||
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
|
||||
if s.inbound == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
pkg := reflect.TypeOf(Server{}).PkgPath()
|
||||
meter := provider.Meter(pkg)
|
||||
|
||||
s.meter, err = proxymetrics.New(ctx, meter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create metrics: %w", err)
|
||||
info, ok := s.inbound.ListenerInfo(accountID)
|
||||
if !ok || info.TunnelIP == "" {
|
||||
return nil
|
||||
}
|
||||
return &proto.ProxyInboundListener{
|
||||
TunnelIp: info.TunnelIP,
|
||||
HttpsPort: uint32(info.HTTPSPort),
|
||||
HttpPort: uint32(info.HTTPPort),
|
||||
}
|
||||
}
|
||||
|
||||
mgmtConn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
// ListenAndServe is the standalone entrypoint. It binds the listener, runs
|
||||
// the proxy until ctx is cancelled or a background goroutine fails, then
|
||||
// drains and stops. Library callers should prefer New + Start + Stop and
|
||||
// own their own shutdown signalling.
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
|
||||
s.ListenAddr = addr
|
||||
if err := s.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
return s.waitAndStop(ctx)
|
||||
}
|
||||
|
||||
// Start brings the proxy up: dials management, configures TLS, binds the
|
||||
// main listener, and spawns the SNI router and HTTPS server goroutines. It
|
||||
// returns once the listener is bound; background errors are surfaced
|
||||
// through Stop's return value. Start is not safe to call twice.
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
s.startMu.Lock()
|
||||
if s.started {
|
||||
s.startMu.Unlock()
|
||||
return errors.New("proxy already started")
|
||||
}
|
||||
s.started = true
|
||||
s.startMu.Unlock()
|
||||
|
||||
s.initLifecycleState()
|
||||
if err := s.initMetrics(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.initManagementClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
defer runCancel()
|
||||
s.runCancel = runCancel
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
|
||||
|
||||
// Create health checker before the mapping worker so it can track
|
||||
// management connectivity from the first stream connection.
|
||||
s.initNetBirdClient()
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
|
||||
@@ -300,34 +342,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
s.initReverseProxy()
|
||||
|
||||
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize geolocation: %w", err)
|
||||
}
|
||||
s.geoRaw = geoLookup
|
||||
if geoLookup != nil {
|
||||
s.geo = geoLookup
|
||||
if err := s.initGeoLookup(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var startupOK bool
|
||||
startupOK := false
|
||||
defer func() {
|
||||
if startupOK {
|
||||
return
|
||||
}
|
||||
if s.geoRaw != nil {
|
||||
if err := s.geoRaw.Close(); err != nil {
|
||||
s.Logger.Debugf("close geolocation on startup failure: %v", err)
|
||||
if closeErr := s.geoRaw.Close(); closeErr != nil {
|
||||
s.Logger.Debugf("close geolocation on startup failure: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.startDebugEndpoint()
|
||||
@@ -336,35 +369,21 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the handler chain from inside out.
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = s.accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
handler = s.hijackTracker.Middleware(handler)
|
||||
handler := s.buildHandlerChain()
|
||||
s.initPrivateInbound(handler, tlsConfig)
|
||||
|
||||
// Start a raw TCP listener; the SNI router peeks at ClientHello
|
||||
// and routes to either the HTTP handler or a TCP relay.
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", addr)
|
||||
ln, err := s.bindMainListener(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
return err
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
|
||||
|
||||
// Set up the SNI router for TCP/HTTP multiplexing on the main port.
|
||||
s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr())
|
||||
s.mainRouter.SetObserver(s.meter)
|
||||
s.mainRouter.SetAccessLogger(s.accessLog)
|
||||
close(s.routerReady)
|
||||
|
||||
// The HTTP server uses the chanListener fed by the SNI router.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Addr: s.ListenAddr,
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
ReadHeaderTimeout: httpReadHeaderTimeout,
|
||||
@@ -374,35 +393,200 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
|
||||
startupOK = true
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
|
||||
httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "")
|
||||
if serveErr := s.https.ServeTLS(s.mainRouter.HTTPListener(), "", ""); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
s.recordRunErr(fmt.Errorf("https server: %w", serveErr))
|
||||
}
|
||||
}()
|
||||
|
||||
routerErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting SNI router on %s", addr)
|
||||
routerErr <- s.mainRouter.Serve(runCtx, ln)
|
||||
s.Logger.Debugf("starting SNI router on %s", s.ListenAddr)
|
||||
if serveErr := s.mainRouter.Serve(runCtx, ln); serveErr != nil {
|
||||
s.recordRunErr(fmt.Errorf("SNI router: %w", serveErr))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop drains in-flight connections, shuts down all background services,
|
||||
// and releases resources. Idempotent; calling it before Start is a no-op.
|
||||
// Returns the first fatal error reported by a background goroutine, if
|
||||
// any. The provided ctx bounds the total wait time; once it is cancelled
|
||||
// Stop returns even if drain is still in flight.
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
s.stopOnce.Do(func() {
|
||||
s.startMu.Lock()
|
||||
started := s.started
|
||||
s.startMu.Unlock()
|
||||
if !started {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.gracefulShutdown()
|
||||
if s.runCancel != nil {
|
||||
s.runCancel()
|
||||
}
|
||||
if s.mgmtConn != nil {
|
||||
if err := s.mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
s.Logger.Warnf("proxy stop deadline exceeded: %v", ctx.Err())
|
||||
}
|
||||
})
|
||||
|
||||
s.startMu.Lock()
|
||||
defer s.startMu.Unlock()
|
||||
return s.runErr
|
||||
}
|
||||
|
||||
// waitAndStop blocks until ctx is cancelled or a background goroutine
|
||||
// reports a fatal error, then drains and stops. Used by ListenAndServe.
|
||||
func (s *Server) waitAndStop(ctx context.Context) error {
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case err := <-routerErr:
|
||||
s.shutdownServices()
|
||||
if err != nil {
|
||||
return fmt.Errorf("SNI router: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
case <-s.runErrCh:
|
||||
}
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout+shutdownServiceTimeout)
|
||||
defer cancel()
|
||||
return s.Stop(stopCtx)
|
||||
}
|
||||
|
||||
// recordRunErr stores the first fatal background error and signals
|
||||
// waitAndStop. Subsequent errors are logged at debug level so the first
|
||||
// cause is preserved.
|
||||
func (s *Server) recordRunErr(err error) {
|
||||
s.startMu.Lock()
|
||||
defer s.startMu.Unlock()
|
||||
if s.runErr != nil {
|
||||
s.Logger.Debugf("background error after first failure: %v", err)
|
||||
return
|
||||
}
|
||||
s.runErr = err
|
||||
if s.runErrCh != nil {
|
||||
close(s.runErrCh)
|
||||
}
|
||||
}
|
||||
|
||||
// initLifecycleState seeds the maps and channels Start needs to wire up
|
||||
// background goroutines. Called once at the top of Start.
|
||||
func (s *Server) initLifecycleState() {
|
||||
s.initDefaults()
|
||||
s.routerReady = make(chan struct{})
|
||||
s.runErrCh = make(chan struct{})
|
||||
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
|
||||
s.portRouters = make(map[uint16]*portRouter)
|
||||
s.svcPorts = make(map[types.ServiceID][]uint16)
|
||||
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
|
||||
}
|
||||
|
||||
// initMetrics builds the prometheus exporter and meter bundle.
|
||||
func (s *Server) initMetrics(ctx context.Context) error {
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create prometheus exporter: %w", err)
|
||||
}
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
pkg := reflect.TypeOf(Server{}).PkgPath()
|
||||
meter := provider.Meter(pkg)
|
||||
s.meter, err = proxymetrics.New(ctx, meter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create metrics: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initManagementClient dials management and stashes the connection so
|
||||
// Stop can close it deterministically.
|
||||
func (s *Server) initManagementClient() error {
|
||||
conn, err := s.dialManagement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mgmtConn = conn
|
||||
s.mgmtClient = proto.NewProxyServiceClient(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// initNetBirdClient builds the multi-tenant embedded NetBird client used
|
||||
// for outbound RoundTripping and (when --private-inbound is on) per-account
|
||||
// inbound listeners.
|
||||
func (s *Server) initNetBirdClient() {
|
||||
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||
MgmtAddr: s.ManagementAddress,
|
||||
WGPort: s.WireguardPort,
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
// On --private the embedded client serves per-account inbound
|
||||
// listeners and must apply management's ACL: keep BlockInbound off
|
||||
// so the engine creates the ACL manager. On the standalone proxy
|
||||
// the embedded client never accepts inbound, so block.
|
||||
BlockInbound: !s.Private,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
}
|
||||
|
||||
// initReverseProxy builds the meter-instrumented reverse proxy. MultiTransport
|
||||
// routes targets opted into direct_upstream through the host's network stack
|
||||
// (stdlib transport); everything else falls through to the embedded NetBird
|
||||
// client. The split is needed so direct_upstream targets resolve DNS via the
|
||||
// proxy host's resolver instead of the tunnel's DNS.
|
||||
func (s *Server) initReverseProxy() {
|
||||
upstreamRT := roundtrip.NewMultiTransport(s.netbird, s.Logger)
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
}
|
||||
|
||||
// initGeoLookup configures the GeoLite2 lookup used for country-based
|
||||
// access restrictions and access-log enrichment.
|
||||
func (s *Server) initGeoLookup() error {
|
||||
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize geolocation: %w", err)
|
||||
}
|
||||
s.geoRaw = geoLookup
|
||||
if geoLookup != nil {
|
||||
s.geo = geoLookup
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildHandlerChain wires the request middlewares from inside out.
|
||||
func (s *Server) buildHandlerChain() http.Handler {
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = s.accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
return s.hijackTracker.Middleware(handler)
|
||||
}
|
||||
|
||||
// bindMainListener binds the main TCP listener and wraps it with PROXY
|
||||
// protocol when configured.
|
||||
func (s *Server) bindMainListener(ctx context.Context) (net.Listener, error) {
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", s.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen on %s: %w", s.ListenAddr, err)
|
||||
}
|
||||
if s.ProxyProtocol {
|
||||
ln = s.wrapProxyProtocol(ln)
|
||||
}
|
||||
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"requested_addr": s.ListenAddr,
|
||||
"bound_addr": ln.Addr().String(),
|
||||
"private": s.Private,
|
||||
"proxy_protocol": s.ProxyProtocol,
|
||||
}).Info("proxy main listener bound")
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
// initDefaults sets fallback values for optional Server fields.
|
||||
@@ -434,6 +618,9 @@ func (s *Server) startDebugEndpoint() {
|
||||
if s.acme != nil {
|
||||
debugHandler.SetCertStatus(s.acme)
|
||||
}
|
||||
if s.inbound != nil {
|
||||
debugHandler.SetInboundProvider(inboundDebugAdapter{mgr: s.inbound})
|
||||
}
|
||||
s.debug = &http.Server{
|
||||
Addr: debugAddr,
|
||||
Handler: debugHandler,
|
||||
@@ -447,16 +634,18 @@ func (s *Server) startDebugEndpoint() {
|
||||
}()
|
||||
}
|
||||
|
||||
// startHealthServer launches the health probe and metrics server.
|
||||
// startHealthServer launches the health probe and metrics server. Empty
|
||||
// HealthAddress disables the probe entirely (intended for library callers
|
||||
// that want to manage their own health surface).
|
||||
func (s *Server) startHealthServer() error {
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = defaultHealthAddr
|
||||
if s.HealthAddress == "" {
|
||||
s.Logger.Debug("health probe disabled (empty HealthAddress)")
|
||||
return nil
|
||||
}
|
||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
|
||||
healthListener, err := net.Listen("tcp", healthAddr)
|
||||
s.healthServer = health.NewServer(s.HealthAddress, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
|
||||
healthListener, err := net.Listen("tcp", s.HealthAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||
return fmt.Errorf("health probe server listen on %s: %w", s.HealthAddress, err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
@@ -507,8 +696,9 @@ func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxypr
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHealthAddr = "localhost:8080"
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
// defaultDebugAddr is the localhost-bound fallback for the debug endpoint
|
||||
// when DebugEndpointAddress is empty.
|
||||
defaultDebugAddr = "localhost:8444"
|
||||
|
||||
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
|
||||
// header after accepting a connection.
|
||||
@@ -661,8 +851,10 @@ func (s *Server) gracefulShutdown() {
|
||||
defer drainCancel()
|
||||
|
||||
s.Logger.Info("draining in-flight connections")
|
||||
if err := s.https.Shutdown(drainCtx); err != nil {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
if s.https != nil {
|
||||
if err := s.https.Shutdown(drainCtx); err != nil {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||
@@ -809,6 +1001,18 @@ func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFu
|
||||
return client.DialContext, nil
|
||||
}
|
||||
|
||||
// initPrivateInbound wires per-account inbound listeners when --private
|
||||
// is set. When the flag is off this is a no-op and the standalone proxy keeps
|
||||
// its byte-for-byte previous behaviour.
|
||||
func (s *Server) initPrivateInbound(handler http.Handler, tlsConfig *tls.Config) {
|
||||
if !s.Private {
|
||||
return
|
||||
}
|
||||
s.inbound = newInboundManager(s.Logger, handler, tlsConfig)
|
||||
s.netbird.SetClientLifecycle(s.inbound.onClientReady, s.inbound.onClientStop)
|
||||
s.Logger.Info("private inbound listeners enabled (per-account :80 + :443)")
|
||||
}
|
||||
|
||||
// notifyError reports a resource error back to management so it can be
|
||||
// surfaced to the user (e.g. port bind failure, dialer resolution error).
|
||||
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
|
||||
@@ -941,9 +1145,6 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
// syncSupported tracks whether management supports SyncMappings.
|
||||
// Starts true; set to false on first Unimplemented error.
|
||||
syncSupported := true
|
||||
initialSyncDone := false
|
||||
|
||||
operation := func() error {
|
||||
@@ -955,25 +1156,41 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
supportsCrowdSec := s.crowdsecRegistry.Available()
|
||||
privateCapability := s.Private
|
||||
// Always true: this build enforces ProxyMapping.private via the auth middleware.
|
||||
supportsPrivateService := true
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
StartedAt: timestamppb.New(s.startTime),
|
||||
Address: s.ProxyURL,
|
||||
Capabilities: &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
Private: &privateCapability,
|
||||
SupportsPrivateService: &supportsPrivateService,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
s.Logger.Debug("management mapping stream established")
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
|
||||
streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone, time.Now())
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
}
|
||||
@@ -990,187 +1207,56 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
supportsCrowdSec := s.crowdsecRegistry.Available()
|
||||
return &proto.ProxyCapabilities{
|
||||
SupportsCustomPorts: &s.SupportsCustomPorts,
|
||||
RequireSubdomain: &s.RequireSubdomain,
|
||||
SupportsCrowdsec: &supportsCrowdSec,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
StartedAt: timestamppb.New(s.startTime),
|
||||
Address: s.ProxyURL,
|
||||
Capabilities: s.proxyCapabilities(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
s.Logger.Debug("management mapping stream established (GetMappingUpdate)")
|
||||
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create sync stream: %w", err)
|
||||
}
|
||||
|
||||
// Send init message.
|
||||
if err := stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Init{
|
||||
Init: &proto.SyncMappingsInit{
|
||||
ProxyId: s.ID,
|
||||
Version: s.Version,
|
||||
StartedAt: timestamppb.New(s.startTime),
|
||||
Address: s.ProxyURL,
|
||||
Capabilities: s.proxyCapabilities(),
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
s.Logger.Debug("management mapping stream established (SyncMappings)")
|
||||
|
||||
return s.handleSyncMappingsStream(ctx, stream, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func isSyncUnimplemented(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
return ok && st.Code() == codes.Unimplemented
|
||||
}
|
||||
|
||||
func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, _ time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := stream.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return nil
|
||||
case err != nil:
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
if err := stream.Send(&proto.SyncMappingsRequest{
|
||||
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send ack: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
|
||||
select {
|
||||
case <-s.routerReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
var snapshotIDs map[types.ServiceID]struct{}
|
||||
if !*initialSyncDone {
|
||||
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
|
||||
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
|
||||
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Shutting down.
|
||||
return ctx.Err()
|
||||
default:
|
||||
msg, err := mappingClient.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
// Mapping connection gracefully terminated by server.
|
||||
return nil
|
||||
case err != nil:
|
||||
// Something has gone horribly wrong, return and hope the parent retries the connection.
|
||||
return fmt.Errorf("receive msg: %w", err)
|
||||
}
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
if !*initialSyncDone {
|
||||
for _, m := range msg.GetMapping() {
|
||||
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||
snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotTracker accumulates service IDs during the initial snapshot and
|
||||
// finalises sync state when the complete flag arrives.
|
||||
type snapshotTracker struct {
|
||||
done *bool
|
||||
connectTime time.Time
|
||||
snapshotIDs map[types.ServiceID]struct{}
|
||||
}
|
||||
|
||||
func (s *Server) newSnapshotTracker(done *bool, connectTime time.Time) *snapshotTracker {
|
||||
var ids map[types.ServiceID]struct{}
|
||||
if !*done {
|
||||
ids = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
return &snapshotTracker{done: done, connectTime: connectTime, snapshotIDs: ids}
|
||||
}
|
||||
|
||||
func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings []*proto.ProxyMapping, syncComplete bool, batchStart time.Time) {
|
||||
if *t.done {
|
||||
return
|
||||
}
|
||||
|
||||
if s.meter != nil {
|
||||
s.meter.RecordSnapshotBatchDuration(time.Since(batchStart))
|
||||
}
|
||||
|
||||
for _, m := range mappings {
|
||||
t.snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
|
||||
if !syncComplete {
|
||||
return
|
||||
}
|
||||
|
||||
s.reconcileSnapshot(ctx, t.snapshotIDs)
|
||||
t.snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*t.done = true
|
||||
if s.meter != nil {
|
||||
s.meter.RecordSnapshotSyncDuration(time.Since(t.connectTime))
|
||||
}
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
@@ -1192,17 +1278,29 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
s.ensurePeers(ctx, mappings)
|
||||
// mappingJSONMarshal dumps mappings on one line with zero-value fields visible for debug logs.
|
||||
var mappingJSONMarshal = protojson.MarshalOptions{
|
||||
Multiline: false,
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
// The full proto dump carries auth_token and header-auth values; gate on debug.
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"mode": mapping.GetMode(),
|
||||
"port": mapping.GetListenPort(),
|
||||
"id": mapping.GetId(),
|
||||
}).Debug("Processing mapping update")
|
||||
if debug {
|
||||
raw, err := mappingJSONMarshal.Marshal(mapping)
|
||||
if err != nil {
|
||||
raw = []byte(fmt.Sprintf("<marshal error: %v>", err))
|
||||
}
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"type": mapping.GetType(),
|
||||
"domain": mapping.GetDomain(),
|
||||
"id": mapping.GetId(),
|
||||
"mapping": string(raw),
|
||||
}).Debug("Processing mapping update")
|
||||
}
|
||||
switch mapping.GetType() {
|
||||
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
|
||||
if err := s.addMapping(ctx, mapping); err != nil {
|
||||
@@ -1228,60 +1326,6 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap
|
||||
}
|
||||
}
|
||||
|
||||
// ensurePeers pre-creates NetBird peers for all unique accounts referenced by
|
||||
// CREATED mappings. Peers for different accounts are created concurrently,
|
||||
// which avoids serializing N×100ms gRPC round-trips during large initial syncs.
|
||||
func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
// Collect one representative mapping per account that needs a new peer.
|
||||
type peerReq struct {
|
||||
accountID types.AccountID
|
||||
svcKey roundtrip.ServiceKey
|
||||
authToken string
|
||||
svcID types.ServiceID
|
||||
}
|
||||
seen := make(map[types.AccountID]struct{})
|
||||
var reqs []peerReq
|
||||
for _, m := range mappings {
|
||||
if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
|
||||
continue
|
||||
}
|
||||
accountID := types.AccountID(m.GetAccountId())
|
||||
if _, ok := seen[accountID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[accountID] = struct{}{}
|
||||
if s.netbird.HasClient(accountID) {
|
||||
continue
|
||||
}
|
||||
reqs = append(reqs, peerReq{
|
||||
accountID: accountID,
|
||||
svcKey: s.serviceKeyForMapping(m),
|
||||
authToken: m.GetAuthToken(),
|
||||
svcID: types.ServiceID(m.GetId()),
|
||||
})
|
||||
}
|
||||
|
||||
if len(reqs) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(reqs))
|
||||
for _, r := range reqs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": r.accountID,
|
||||
"service_id": r.svcID,
|
||||
"error": err,
|
||||
}).Warn("failed to pre-create peer for account")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// addMapping registers a service mapping and starts the appropriate relay or routes.
|
||||
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
@@ -1353,12 +1397,16 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
|
||||
if s.acme != nil {
|
||||
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
|
||||
}
|
||||
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
|
||||
httpRoute := nbtcp.Route{
|
||||
Type: nbtcp.RouteHTTP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
})
|
||||
}
|
||||
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
|
||||
if s.inbound != nil {
|
||||
s.inbound.AddRoute(accountID, nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
|
||||
}
|
||||
if err := s.updateMapping(ctx, mapping); err != nil {
|
||||
return fmt.Errorf("update mapping for domain %q: %w", d, err)
|
||||
}
|
||||
@@ -1718,7 +1766,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions, mapping.GetPrivate()); err != nil {
|
||||
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
|
||||
}
|
||||
m := s.protoToMapping(ctx, mapping)
|
||||
@@ -1774,6 +1822,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
|
||||
}
|
||||
// Remove SNI route from the main router (covers both HTTP and main-port TLS).
|
||||
s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID)
|
||||
if s.inbound != nil {
|
||||
s.inbound.RemoveRoute(types.AccountID(mapping.GetAccountId()), nbtcp.SNIHost(host), svcID)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and delete tracked custom-port entries atomically.
|
||||
@@ -1861,6 +1912,7 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
|
||||
if d := opts.GetRequestTimeout(); d != nil {
|
||||
pt.RequestTimeout = d.AsDuration()
|
||||
}
|
||||
pt.DirectUpstream = opts.GetDirectUpstream()
|
||||
}
|
||||
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
|
||||
paths[pathMapping.GetPath()] = pt
|
||||
|
||||
Reference in New Issue
Block a user