feat(private-service): expose NetBird-only services over tunnel peers

Adds a new "private" service mode for the reverse proxy: services
reachable exclusively over the embedded WireGuard tunnel, gated by
per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call
  ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5):
  capability gate. Management never streams private mappings to
  proxies that don't claim the capability; the broadcast path applies
  the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer,
  checks the peer's groups against service.AccessGroups, and mints
  a session JWT on success. checkPeerGroupAccess fails closed when
  a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry
  peer_group_ids + peer_group_names so the proxy can authorise
  policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account
  inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird
  client and dial the target via the proxy host's network stack for
  upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON-
  serialised). Validate() rejects bearer auth on private services.
  Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware.
  Request handler routes private services through forwardWithTunnelPeer
  and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and
  injectPrivateServicePolicies (synthetic ACL) gate on
  len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in
  proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS
  on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache
  with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the
  account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the
  full ProxyMapping JSON from info logs (auth_token + header-auth
  hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups,
  group_names. sessionkey.Claims carries Email + Groups + GroupNames
  so the proxy can stamp identity onto upstream requests without an
  extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the
  proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the
  authenticated identity (strips client-supplied values first to
  prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's
  group memberships at write time so the dashboard can render group
  context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster
  gains private + supports_private. ReverseProxyTarget target_type
  enum gains "cluster". ServiceTargetOptions gains direct_upstream.
  ProxyAccessLog gains user_groups.
This commit is contained in:
mlsmaycon
2026-05-20 21:39:22 +02:00
parent 37052fd5bc
commit 167ee08e14
72 changed files with 6584 additions and 2586 deletions

View File

@@ -32,11 +32,10 @@ import (
"go.opentelemetry.io/otel/sdk/metric"
"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
@@ -114,9 +113,28 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
inbound *inboundManager
// Lifecycle state — populated by Start, consumed by Stop. The fields
// stay zero on a fresh Server until Start runs so direct struct
// construction (`&Server{...}`) used by tests still works.
runCancel context.CancelFunc
mgmtConn *grpc.ClientConn
runErr error
runErrCh chan struct{}
startMu sync.Mutex
started bool
stopOnce sync.Once
// Mostly used for debugging on management.
startTime time.Time
// ListenAddr is the address the main TCP listener binds. Populated by
// New from Config or by ListenAndServe from its addr argument.
ListenAddr string
ID string
Logger *log.Logger
Version string
@@ -177,6 +195,14 @@ type Server struct {
// in front of this proxy's cluster domain. When true, accounts cannot
// create services on the bare cluster domain.
RequireSubdomain bool
// Private flags this proxy as embedded in a netbird client and serving
// exclusively over the WireGuard tunnel (i.e. `netbird proxy`). Reported
// upstream as a capability so dashboards can distinguish per-peer
// clusters from centralised ones, and turns on per-account inbound
// listeners on each embedded client's netstack: every account that
// registers a service exposes :80 + :443 inside its own WG tunnel,
// scoped to that account's services only.
Private bool
// MaxDialTimeout caps the per-service backend dial timeout.
// When the API sends a timeout, it is clamped to this value.
// When the API sends no timeout, this value is used as the default.
@@ -222,12 +248,16 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, se
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
req := &proto.SendStatusUpdateRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Status: status,
CertificateIssued: false,
})
}
if connected {
req.InboundListener = s.inboundListenerProto(accountID)
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, req)
return err
}
@@ -238,56 +268,68 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
AccountId: string(accountID),
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
CertificateIssued: true,
InboundListener: s.inboundListenerProto(accountID),
})
return err
}
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.initDefaults()
s.routerReady = make(chan struct{})
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
s.portRouters = make(map[uint16]*portRouter)
s.svcPorts = make(map[types.ServiceID][]uint16)
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
exporter, err := prometheus.New()
if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
// inboundListenerProto resolves the per-account inbound listener state for
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
// or the account has no live listener so management treats the field as
// absent.
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
if s.inbound == nil {
return nil
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
info, ok := s.inbound.ListenerInfo(accountID)
if !ok || info.TunnelIP == "" {
return nil
}
return &proto.ProxyInboundListener{
TunnelIp: info.TunnelIP,
HttpsPort: uint32(info.HTTPSPort),
HttpPort: uint32(info.HTTPPort),
}
}
mgmtConn, err := s.dialManagement()
if err != nil {
// ListenAndServe is the standalone entrypoint. It binds the listener, runs
// the proxy until ctx is cancelled or a background goroutine fails, then
// drains and stops. Library callers should prefer New + Start + Stop and
// own their own shutdown signalling.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
s.ListenAddr = addr
if err := s.Start(ctx); err != nil {
return err
}
defer func() {
if err := mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}()
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
return s.waitAndStop(ctx)
}
// Start brings the proxy up: dials management, configures TLS, binds the
// main listener, and spawns the SNI router and HTTPS server goroutines. It
// returns once the listener is bound; background errors are surfaced
// through Stop's return value. Start is not safe to call twice.
func (s *Server) Start(ctx context.Context) error {
s.startMu.Lock()
if s.started {
s.startMu.Unlock()
return errors.New("proxy already started")
}
s.started = true
s.startMu.Unlock()
s.initLifecycleState()
if err := s.initMetrics(ctx); err != nil {
return err
}
if err := s.initManagementClient(); err != nil {
return err
}
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
s.runCancel = runCancel
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.initNetBirdClient()
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
@@ -300,34 +342,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
s.initReverseProxy()
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
if err := s.initGeoLookup(); err != nil {
return err
}
var startupOK bool
startupOK := false
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
if closeErr := s.geoRaw.Close(); closeErr != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", closeErr)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.startDebugEndpoint()
@@ -336,35 +369,21 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Build the handler chain from inside out.
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = s.accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
handler = s.hijackTracker.Middleware(handler)
handler := s.buildHandlerChain()
s.initPrivateInbound(handler, tlsConfig)
// Start a raw TCP listener; the SNI router peeks at ClientHello
// and routes to either the HTTP handler or a TCP relay.
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", addr)
ln, err := s.bindMainListener(ctx)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
return err
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
// Set up the SNI router for TCP/HTTP multiplexing on the main port.
s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr())
s.mainRouter.SetObserver(s.meter)
s.mainRouter.SetAccessLogger(s.accessLog)
close(s.routerReady)
// The HTTP server uses the chanListener fed by the SNI router.
s.https = &http.Server{
Addr: addr,
Addr: s.ListenAddr,
Handler: handler,
TLSConfig: tlsConfig,
ReadHeaderTimeout: httpReadHeaderTimeout,
@@ -374,35 +393,200 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
startupOK = true
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "")
if serveErr := s.https.ServeTLS(s.mainRouter.HTTPListener(), "", ""); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
s.recordRunErr(fmt.Errorf("https server: %w", serveErr))
}
}()
routerErr := make(chan error, 1)
go func() {
s.Logger.Debugf("starting SNI router on %s", addr)
routerErr <- s.mainRouter.Serve(runCtx, ln)
s.Logger.Debugf("starting SNI router on %s", s.ListenAddr)
if serveErr := s.mainRouter.Serve(runCtx, ln); serveErr != nil {
s.recordRunErr(fmt.Errorf("SNI router: %w", serveErr))
}
}()
return nil
}
// Stop drains in-flight connections, shuts down all background services,
// and releases resources. Idempotent; calling it before Start is a no-op.
// Returns the first fatal error reported by a background goroutine, if
// any. The provided ctx bounds the total wait time; once it is cancelled
// Stop returns even if drain is still in flight.
func (s *Server) Stop(ctx context.Context) error {
s.stopOnce.Do(func() {
s.startMu.Lock()
started := s.started
s.startMu.Unlock()
if !started {
return
}
done := make(chan struct{})
go func() {
defer close(done)
s.gracefulShutdown()
if s.runCancel != nil {
s.runCancel()
}
if s.mgmtConn != nil {
if err := s.mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}
}()
select {
case <-done:
case <-ctx.Done():
s.Logger.Warnf("proxy stop deadline exceeded: %v", ctx.Err())
}
})
s.startMu.Lock()
defer s.startMu.Unlock()
return s.runErr
}
// waitAndStop blocks until ctx is cancelled or a background goroutine
// reports a fatal error, then drains and stops. Used by ListenAndServe.
func (s *Server) waitAndStop(ctx context.Context) error {
select {
case err := <-httpsErr:
s.shutdownServices()
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("https server: %w", err)
}
return nil
case err := <-routerErr:
s.shutdownServices()
if err != nil {
return fmt.Errorf("SNI router: %w", err)
}
return nil
case <-ctx.Done():
s.gracefulShutdown()
return nil
case <-s.runErrCh:
}
stopCtx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout+shutdownServiceTimeout)
defer cancel()
return s.Stop(stopCtx)
}
// recordRunErr stores the first fatal background error and signals
// waitAndStop. Subsequent errors are logged at debug level so the first
// cause is preserved.
func (s *Server) recordRunErr(err error) {
s.startMu.Lock()
defer s.startMu.Unlock()
if s.runErr != nil {
s.Logger.Debugf("background error after first failure: %v", err)
return
}
s.runErr = err
if s.runErrCh != nil {
close(s.runErrCh)
}
}
// initLifecycleState seeds the maps and channels Start needs to wire up
// background goroutines. Called once at the top of Start.
func (s *Server) initLifecycleState() {
s.initDefaults()
s.routerReady = make(chan struct{})
s.runErrCh = make(chan struct{})
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
s.portRouters = make(map[uint16]*portRouter)
s.svcPorts = make(map[types.ServiceID][]uint16)
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
}
// initMetrics builds the prometheus exporter and meter bundle.
func (s *Server) initMetrics(ctx context.Context) error {
exporter, err := prometheus.New()
if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
}
return nil
}
// initManagementClient dials management and stashes the connection so
// Stop can close it deterministically.
func (s *Server) initManagementClient() error {
conn, err := s.dialManagement()
if err != nil {
return err
}
s.mgmtConn = conn
s.mgmtClient = proto.NewProxyServiceClient(conn)
return nil
}
// initNetBirdClient builds the multi-tenant embedded NetBird client used
// for outbound RoundTripping and (when --private-inbound is on) per-account
// inbound listeners.
func (s *Server) initNetBirdClient() {
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
// On --private the embedded client serves per-account inbound
// listeners and must apply management's ACL: keep BlockInbound off
// so the engine creates the ACL manager. On the standalone proxy
// the embedded client never accepts inbound, so block.
BlockInbound: !s.Private,
}, s.Logger, s, s.mgmtClient)
}
// initReverseProxy builds the meter-instrumented reverse proxy. MultiTransport
// routes targets opted into direct_upstream through the host's network stack
// (stdlib transport); everything else falls through to the embedded NetBird
// client. The split is needed so direct_upstream targets resolve DNS via the
// proxy host's resolver instead of the tunnel's DNS.
func (s *Server) initReverseProxy() {
upstreamRT := roundtrip.NewMultiTransport(s.netbird, s.Logger)
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger)
}
// initGeoLookup configures the GeoLite2 lookup used for country-based
// access restrictions and access-log enrichment.
func (s *Server) initGeoLookup() error {
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
}
return nil
}
// buildHandlerChain wires the request middlewares from inside out.
func (s *Server) buildHandlerChain() http.Handler {
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = s.accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
return s.hijackTracker.Middleware(handler)
}
// bindMainListener binds the main TCP listener and wraps it with PROXY
// protocol when configured.
func (s *Server) bindMainListener(ctx context.Context) (net.Listener, error) {
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", s.ListenAddr)
if err != nil {
return nil, fmt.Errorf("listen on %s: %w", s.ListenAddr, err)
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
s.Logger.WithFields(log.Fields{
"requested_addr": s.ListenAddr,
"bound_addr": ln.Addr().String(),
"private": s.Private,
"proxy_protocol": s.ProxyProtocol,
}).Info("proxy main listener bound")
return ln, nil
}
// initDefaults sets fallback values for optional Server fields.
@@ -434,6 +618,9 @@ func (s *Server) startDebugEndpoint() {
if s.acme != nil {
debugHandler.SetCertStatus(s.acme)
}
if s.inbound != nil {
debugHandler.SetInboundProvider(inboundDebugAdapter{mgr: s.inbound})
}
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
@@ -447,16 +634,18 @@ func (s *Server) startDebugEndpoint() {
}()
}
// startHealthServer launches the health probe and metrics server.
// startHealthServer launches the health probe and metrics server. Empty
// HealthAddress disables the probe entirely (intended for library callers
// that want to manage their own health surface).
func (s *Server) startHealthServer() error {
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = defaultHealthAddr
if s.HealthAddress == "" {
s.Logger.Debug("health probe disabled (empty HealthAddress)")
return nil
}
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", healthAddr)
s.healthServer = health.NewServer(s.HealthAddress, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", s.HealthAddress)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
return fmt.Errorf("health probe server listen on %s: %w", s.HealthAddress, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
@@ -507,8 +696,9 @@ func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxypr
}
const (
defaultHealthAddr = "localhost:8080"
defaultDebugAddr = "localhost:8444"
// defaultDebugAddr is the localhost-bound fallback for the debug endpoint
// when DebugEndpointAddress is empty.
defaultDebugAddr = "localhost:8444"
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
// header after accepting a connection.
@@ -661,8 +851,10 @@ func (s *Server) gracefulShutdown() {
defer drainCancel()
s.Logger.Info("draining in-flight connections")
if err := s.https.Shutdown(drainCtx); err != nil {
s.Logger.Warnf("https server drain: %v", err)
if s.https != nil {
if err := s.https.Shutdown(drainCtx); err != nil {
s.Logger.Warnf("https server drain: %v", err)
}
}
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
@@ -809,6 +1001,18 @@ func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFu
return client.DialContext, nil
}
// initPrivateInbound wires per-account inbound listeners when --private
// is set. When the flag is off this is a no-op and the standalone proxy keeps
// its byte-for-byte previous behaviour.
func (s *Server) initPrivateInbound(handler http.Handler, tlsConfig *tls.Config) {
if !s.Private {
return
}
s.inbound = newInboundManager(s.Logger, handler, tlsConfig)
s.netbird.SetClientLifecycle(s.inbound.onClientReady, s.inbound.onClientStop)
s.Logger.Info("private inbound listeners enabled (per-account :80 + :443)")
}
// notifyError reports a resource error back to management so it can be
// surfaced to the user (e.g. port bind failure, dialer resolution error).
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
@@ -941,9 +1145,6 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
Clock: backoff.SystemClock,
}
// syncSupported tracks whether management supports SyncMappings.
// Starts true; set to false on first Unimplemented error.
syncSupported := true
initialSyncDone := false
operation := func() error {
@@ -955,25 +1156,41 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
s.healthChecker.SetManagementConnected(false)
}
var streamErr error
if syncSupported {
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
if isSyncUnimplemented(streamErr) {
syncSupported = false
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
}
} else {
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
supportsCrowdSec := s.crowdsecRegistry.Available()
privateCapability := s.Private
// Always true: this build enforces ProxyMapping.private via the auth middleware.
supportsPrivateService := true
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
Address: s.ProxyURL,
Capabilities: &proto.ProxyCapabilities{
SupportsCustomPorts: &s.SupportsCustomPorts,
RequireSubdomain: &s.RequireSubdomain,
SupportsCrowdsec: &supportsCrowdSec,
Private: &privateCapability,
SupportsPrivateService: &supportsPrivateService,
},
})
if err != nil {
return fmt.Errorf("create mapping stream: %w", err)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("management mapping stream established")
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone, time.Now())
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
if streamErr == nil {
return fmt.Errorf("stream closed by server")
}
@@ -990,187 +1207,56 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
}
}
func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
supportsCrowdSec := s.crowdsecRegistry.Available()
return &proto.ProxyCapabilities{
SupportsCustomPorts: &s.SupportsCustomPorts,
RequireSubdomain: &s.RequireSubdomain,
SupportsCrowdsec: &supportsCrowdSec,
}
}
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
connectTime := time.Now()
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
Address: s.ProxyURL,
Capabilities: s.proxyCapabilities(),
})
if err != nil {
return fmt.Errorf("create mapping stream: %w", err)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("management mapping stream established (GetMappingUpdate)")
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
}
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
connectTime := time.Now()
stream, err := client.SyncMappings(ctx)
if err != nil {
return fmt.Errorf("create sync stream: %w", err)
}
// Send init message.
if err := stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
Address: s.ProxyURL,
Capabilities: s.proxyCapabilities(),
},
},
}); err != nil {
return fmt.Errorf("send sync init: %w", err)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("management mapping stream established (SyncMappings)")
return s.handleSyncMappingsStream(ctx, stream, initialSyncDone, connectTime)
}
func isSyncUnimplemented(err error) bool {
if err == nil {
return false
}
st, ok := grpcstatus.FromError(err)
return ok && st.Code() == codes.Unimplemented
}
func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error {
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, _ time.Time) error {
select {
case <-s.routerReady:
case <-ctx.Done():
return ctx.Err()
}
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
msg, err := stream.Recv()
switch {
case errors.Is(err, io.EOF):
return nil
case err != nil:
return fmt.Errorf("receive msg: %w", err)
}
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
if err := stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
}); err != nil {
return fmt.Errorf("send ack: %w", err)
}
}
}
}
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
select {
case <-s.routerReady:
case <-ctx.Done():
return ctx.Err()
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
for {
// Check for context completion to gracefully shutdown.
select {
case <-ctx.Done():
// Shutting down.
return ctx.Err()
default:
msg, err := mappingClient.Recv()
switch {
case errors.Is(err, io.EOF):
// Mapping connection gracefully terminated by server.
return nil
case err != nil:
// Something has gone horribly wrong, return and hope the parent retries the connection.
return fmt.Errorf("receive msg: %w", err)
}
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
if !*initialSyncDone {
for _, m := range msg.GetMapping() {
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
}
}
}
}
// snapshotTracker accumulates service IDs during the initial snapshot and
// finalises sync state when the complete flag arrives.
type snapshotTracker struct {
done *bool
connectTime time.Time
snapshotIDs map[types.ServiceID]struct{}
}
func (s *Server) newSnapshotTracker(done *bool, connectTime time.Time) *snapshotTracker {
var ids map[types.ServiceID]struct{}
if !*done {
ids = make(map[types.ServiceID]struct{})
}
return &snapshotTracker{done: done, connectTime: connectTime, snapshotIDs: ids}
}
func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings []*proto.ProxyMapping, syncComplete bool, batchStart time.Time) {
if *t.done {
return
}
if s.meter != nil {
s.meter.RecordSnapshotBatchDuration(time.Since(batchStart))
}
for _, m := range mappings {
t.snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if !syncComplete {
return
}
s.reconcileSnapshot(ctx, t.snapshotIDs)
t.snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*t.done = true
if s.meter != nil {
s.meter.RecordSnapshotSyncDuration(time.Since(t.connectTime))
}
s.Logger.Info("Initial mapping sync complete")
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
@@ -1192,17 +1278,29 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
s.ensurePeers(ctx, mappings)
// mappingJSONMarshal dumps mappings on one line with zero-value fields visible for debug logs.
var mappingJSONMarshal = protojson.MarshalOptions{
Multiline: false,
EmitUnpopulated: true,
UseProtoNames: true,
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
// The full proto dump carries auth_token and header-auth values; gate on debug.
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"mode": mapping.GetMode(),
"port": mapping.GetListenPort(),
"id": mapping.GetId(),
}).Debug("Processing mapping update")
if debug {
raw, err := mappingJSONMarshal.Marshal(mapping)
if err != nil {
raw = []byte(fmt.Sprintf("<marshal error: %v>", err))
}
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"id": mapping.GetId(),
"mapping": string(raw),
}).Debug("Processing mapping update")
}
switch mapping.GetType() {
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil {
@@ -1228,60 +1326,6 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap
}
}
// ensurePeers pre-creates NetBird peers for all unique accounts referenced by
// CREATED mappings. Peers for different accounts are created concurrently,
// which avoids serializing N×100ms gRPC round-trips during large initial syncs.
func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) {
// Collect one representative mapping per account that needs a new peer.
type peerReq struct {
accountID types.AccountID
svcKey roundtrip.ServiceKey
authToken string
svcID types.ServiceID
}
seen := make(map[types.AccountID]struct{})
var reqs []peerReq
for _, m := range mappings {
if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
continue
}
accountID := types.AccountID(m.GetAccountId())
if _, ok := seen[accountID]; ok {
continue
}
seen[accountID] = struct{}{}
if s.netbird.HasClient(accountID) {
continue
}
reqs = append(reqs, peerReq{
accountID: accountID,
svcKey: s.serviceKeyForMapping(m),
authToken: m.GetAuthToken(),
svcID: types.ServiceID(m.GetId()),
})
}
if len(reqs) <= 1 {
return
}
var wg sync.WaitGroup
wg.Add(len(reqs))
for _, r := range reqs {
go func() {
defer wg.Done()
if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": r.accountID,
"service_id": r.svcID,
"error": err,
}).Warn("failed to pre-create peer for account")
}
}()
}
wg.Wait()
}
// addMapping registers a service mapping and starts the appropriate relay or routes.
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
accountID := types.AccountID(mapping.GetAccountId())
@@ -1353,12 +1397,16 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
}
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
})
}
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
if s.inbound != nil {
s.inbound.AddRoute(accountID, nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
}
if err := s.updateMapping(ctx, mapping); err != nil {
return fmt.Errorf("update mapping for domain %q: %w", d, err)
}
@@ -1718,7 +1766,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions, mapping.GetPrivate()); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
}
m := s.protoToMapping(ctx, mapping)
@@ -1774,6 +1822,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
}
// Remove SNI route from the main router (covers both HTTP and main-port TLS).
s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID)
if s.inbound != nil {
s.inbound.RemoveRoute(types.AccountID(mapping.GetAccountId()), nbtcp.SNIHost(host), svcID)
}
}
// Extract and delete tracked custom-port entries atomically.
@@ -1861,6 +1912,7 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
if d := opts.GetRequestTimeout(); d != nil {
pt.RequestTimeout = d.AsDuration()
}
pt.DirectUpstream = opts.GetDirectUpstream()
}
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
paths[pathMapping.GetPath()] = pt