mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 08:09:55 +00:00
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes. Wire contract - ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed. - ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy. - ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups. - ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip. - ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards. - PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard. Data model - Service.Private (bool) + Service.AccessGroups ([]string, JSON-serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns. - DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure. - Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0. Proxy - /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go. 
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack. - proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction. - Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC. - proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level). Identity forwarding - ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request. - CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing). - AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships. OpenAPI/dashboard surface - ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
748 lines
22 KiB
Go
748 lines
22 KiB
Go
package tcp
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
|
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
|
"github.com/netbirdio/netbird/util/netrelay"
|
|
)
|
|
|
|
// defaultDialTimeout bounds a backend dial when the route does not carry
// its own DialTimeout.
const defaultDialTimeout = 30 * time.Second
|
|
|
|
// errAccessRestricted is the sentinel relayTCP returns when a connection is
// denied by access restrictions. The denial is already logged at debug
// level, so callers use it to skip their warn-level log.
var errAccessRestricted = errors.New("rejected by access restrictions")
|
|
|
|
// SNIHost is the typed map key under which routes are stored and looked up
// by SNI hostname.
type SNIHost string
|
|
|
|
// RouteType selects the handling a routed connection receives.
type RouteType int
|
|
|
|
const (
|
|
// RouteHTTP routes the connection through the HTTP reverse proxy.
|
|
RouteHTTP RouteType = iota
|
|
// RouteTCP relays the connection directly to the backend (TLS passthrough).
|
|
RouteTCP
|
|
)
|
|
|
|
const (
	// sniPeekTimeout is the read deadline applied while waiting for the
	// TLS ClientHello.
	sniPeekTimeout = 5 * time.Second
	// DefaultDrainTimeout is how long shutdown waits, by default, for
	// in-flight relay connections to finish.
	DefaultDrainTimeout = 30 * time.Second
	// DefaultMaxRelayConns is the default per-router limit on concurrent
	// TCP relay connections.
	DefaultMaxRelayConns = 4096
	// httpChannelBuffer is the buffer size of the channels that feed
	// connections to the HTTP servers.
	httpChannelBuffer = 4096
)
|
|
|
|
// DialResolver returns a DialContextFunc for the given account.
|
|
type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error)
|
|
|
|
// Route describes where a connection for a given SNI should be sent.
|
|
type Route struct {
|
|
Type RouteType
|
|
AccountID types.AccountID
|
|
ServiceID types.ServiceID
|
|
// Domain is the service's configured domain, used for access log entries.
|
|
Domain string
|
|
// Protocol is the frontend protocol (tcp, tls), used for access log entries.
|
|
Protocol accesslog.Protocol
|
|
// Target is the backend address for TCP relay (e.g. "10.0.0.5:5432").
|
|
Target string
|
|
// ProxyProtocol enables sending a PROXY protocol v2 header to the backend.
|
|
ProxyProtocol bool
|
|
// DialTimeout overrides the default dial timeout for this route.
|
|
// Zero uses defaultDialTimeout.
|
|
DialTimeout time.Duration
|
|
// SessionIdleTimeout overrides the default idle timeout for relay connections.
|
|
// Zero uses DefaultIdleTimeout.
|
|
SessionIdleTimeout time.Duration
|
|
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
|
|
Filter *restrict.Filter
|
|
}
|
|
|
|
// l4Logger sends layer-4 access log entries to the management server.
|
|
type l4Logger interface {
|
|
LogL4(entry accesslog.L4Entry)
|
|
}
|
|
|
|
// RelayObserver receives callbacks for TCP relay lifecycle events.
|
|
// All methods must be safe for concurrent use.
|
|
type RelayObserver interface {
|
|
TCPRelayStarted(accountID types.AccountID)
|
|
TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64)
|
|
TCPRelayDialError(accountID types.AccountID)
|
|
TCPRelayRejected(accountID types.AccountID)
|
|
}
|
|
|
|
// Router accepts raw TCP connections on a shared listener, peeks at
|
|
// the TLS ClientHello to extract the SNI, and routes the connection
|
|
// to either the HTTP reverse proxy or a direct TCP relay.
|
|
type Router struct {
|
|
logger *log.Logger
|
|
// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
|
|
httpCh chan net.Conn
|
|
httpListener *chanListener
|
|
// httpPlainCh feeds non-TLS HTTP connections to a parallel http.Server.
|
|
// Set only when NewRouter is called with WithPlainHTTP option (used by
|
|
// per-account inbound listeners that accept both :80 and :443 traffic).
|
|
// Nil for the host SNI router and for port routers.
|
|
httpPlainCh chan net.Conn
|
|
httpPlainListener *chanListener
|
|
mu sync.RWMutex
|
|
routes map[SNIHost][]Route
|
|
fallback *Route
|
|
draining bool
|
|
dialResolve DialResolver
|
|
activeConns sync.WaitGroup
|
|
activeRelays sync.WaitGroup
|
|
relaySem chan struct{}
|
|
drainDone chan struct{}
|
|
observer RelayObserver
|
|
accessLog l4Logger
|
|
geo restrict.GeoResolver
|
|
// svcCtxs tracks a context per service ID. All relay goroutines for a
|
|
// service derive from its context; canceling it kills them immediately.
|
|
svcCtxs map[types.ServiceID]context.Context
|
|
svcCancels map[types.ServiceID]context.CancelFunc
|
|
}
|
|
|
|
// RouterOption customises Router construction.
|
|
type RouterOption func(*Router)
|
|
|
|
// WithPlainHTTP enables a parallel plain-HTTP channel on the router. When
|
|
// set, connections whose first byte is not a TLS handshake are forwarded
|
|
// to the plain channel returned by HTTPListenerPlain instead of the TLS
|
|
// channel. Used by per-account inbound listeners that share both :80 and
|
|
// :443 traffic on the same router.
|
|
func WithPlainHTTP(addr net.Addr) RouterOption {
|
|
return func(r *Router) {
|
|
ch := make(chan net.Conn, httpChannelBuffer)
|
|
r.httpPlainCh = ch
|
|
r.httpPlainListener = newChanListener(ch, addr)
|
|
}
|
|
}
|
|
|
|
// NewRouter creates a new SNI-based connection router.
|
|
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr, opts ...RouterOption) *Router {
|
|
httpCh := make(chan net.Conn, httpChannelBuffer)
|
|
r := &Router{
|
|
logger: logger,
|
|
httpCh: httpCh,
|
|
httpListener: newChanListener(httpCh, addr),
|
|
routes: make(map[SNIHost][]Route),
|
|
dialResolve: dialResolve,
|
|
relaySem: make(chan struct{}, DefaultMaxRelayConns),
|
|
svcCtxs: make(map[types.ServiceID]context.Context),
|
|
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
|
}
|
|
for _, opt := range opts {
|
|
opt(r)
|
|
}
|
|
return r
|
|
}
|
|
|
|
// NewPortRouter creates a Router for a dedicated port without an HTTP
|
|
// channel. Connections that don't match any SNI route fall through to
|
|
// the fallback relay (if set) or are closed.
|
|
func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router {
|
|
return &Router{
|
|
logger: logger,
|
|
routes: make(map[SNIHost][]Route),
|
|
dialResolve: dialResolve,
|
|
relaySem: make(chan struct{}, DefaultMaxRelayConns),
|
|
svcCtxs: make(map[types.ServiceID]context.Context),
|
|
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
|
}
|
|
}
|
|
|
|
// HTTPListener returns a net.Listener that yields connections routed
|
|
// to the HTTP handler. Use this with http.Server.ServeTLS.
|
|
func (r *Router) HTTPListener() net.Listener {
|
|
return r.httpListener
|
|
}
|
|
|
|
// HTTPListenerPlain returns a net.Listener yielding non-TLS connections
|
|
// for use with a parallel plain http.Server. Returns nil when the router
|
|
// was not constructed with WithPlainHTTP.
|
|
func (r *Router) HTTPListenerPlain() net.Listener {
|
|
if r.httpPlainListener == nil {
|
|
return nil
|
|
}
|
|
return r.httpPlainListener
|
|
}
|
|
|
|
// AddRoute registers an SNI route. Multiple routes for the same host are
|
|
// stored and resolved by priority at lookup time (HTTP > TCP).
|
|
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
|
func (r *Router) AddRoute(host SNIHost, route Route) {
|
|
host = SNIHost(strings.ToLower(string(host)))
|
|
if host == "" {
|
|
return
|
|
}
|
|
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
routes := r.routes[host]
|
|
for i, existing := range routes {
|
|
if existing.ServiceID == route.ServiceID {
|
|
r.cancelServiceLocked(route.ServiceID)
|
|
routes[i] = route
|
|
return
|
|
}
|
|
}
|
|
r.routes[host] = append(routes, route)
|
|
}
|
|
|
|
// RemoveRoute removes the route for the given host and service ID.
|
|
// Active relay connections for the service are closed immediately.
|
|
// If other routes remain for the host, they are preserved.
|
|
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
|
|
host = SNIHost(strings.ToLower(string(host)))
|
|
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool {
|
|
return route.ServiceID == svcID
|
|
})
|
|
if len(r.routes[host]) == 0 {
|
|
delete(r.routes, host)
|
|
}
|
|
r.cancelServiceLocked(svcID)
|
|
}
|
|
|
|
// SetFallback registers a catch-all route for connections that don't
|
|
// match any SNI route. On a port router this handles plain TCP relay;
|
|
// on the main router it takes priority over the HTTP channel.
|
|
func (r *Router) SetFallback(route Route) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.fallback = &route
|
|
}
|
|
|
|
// RemoveFallback clears the catch-all fallback route and closes any
|
|
// active relay connections for the given service.
|
|
func (r *Router) RemoveFallback(svcID types.ServiceID) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.fallback = nil
|
|
r.cancelServiceLocked(svcID)
|
|
}
|
|
|
|
// SetObserver sets the relay lifecycle observer. Must be called before Serve.
|
|
func (r *Router) SetObserver(obs RelayObserver) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.observer = obs
|
|
}
|
|
|
|
// SetAccessLogger sets the L4 access logger. Must be called before Serve.
|
|
func (r *Router) SetAccessLogger(l l4Logger) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.accessLog = l
|
|
}
|
|
|
|
// getObserver returns the current relay observer under the read lock.
|
|
func (r *Router) getObserver() RelayObserver {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return r.observer
|
|
}
|
|
|
|
// IsEmpty returns true when the router has no SNI routes and no fallback.
|
|
func (r *Router) IsEmpty() bool {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return len(r.routes) == 0 && r.fallback == nil
|
|
}
|
|
|
|
// Serve accepts connections from ln and routes them based on SNI.
|
|
// It blocks until ctx is canceled or ln is closed, then drains
|
|
// active relay connections up to DefaultDrainTimeout.
|
|
func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
|
done := make(chan struct{})
|
|
defer close(done)
|
|
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = ln.Close()
|
|
if r.httpListener != nil {
|
|
r.httpListener.Close()
|
|
}
|
|
if r.httpPlainListener != nil {
|
|
r.httpPlainListener.Close()
|
|
}
|
|
case <-done:
|
|
}
|
|
}()
|
|
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
|
if ok := r.Drain(DefaultDrainTimeout); !ok {
|
|
r.logger.Warn("timed out waiting for connections to drain")
|
|
}
|
|
return nil
|
|
}
|
|
r.logger.Debugf("SNI router accept: %v", err)
|
|
continue
|
|
}
|
|
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
|
|
r.activeConns.Add(1)
|
|
go func() {
|
|
defer r.activeConns.Done()
|
|
r.handleConn(ctx, conn)
|
|
}()
|
|
}
|
|
}
|
|
|
|
// HandleConn lets external accept loops feed a connection through the
|
|
// router's peek-and-dispatch logic. Use this when the same router serves
|
|
// a secondary listener (for example, a per-account inbound :80 socket
|
|
// alongside its :443 socket).
|
|
func (r *Router) HandleConn(ctx context.Context, conn net.Conn) {
|
|
r.activeConns.Add(1)
|
|
defer r.activeConns.Done()
|
|
r.handleConn(ctx, conn)
|
|
}
|
|
|
|
// handleConn peeks at the TLS ClientHello and routes the connection.
|
|
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
|
// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
|
|
// fallback port), skip the TLS peek entirely to avoid read errors on
|
|
// non-TLS connections and reduce latency.
|
|
if r.isFallbackOnly() {
|
|
r.logger.Debugf("SNI router fallback-only mode for conn from %s; skipping ClientHello peek", conn.RemoteAddr())
|
|
r.handleUnmatched(ctx, conn, false)
|
|
return
|
|
}
|
|
|
|
if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil {
|
|
r.logger.Debugf("set SNI peek deadline: %v", err)
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
|
|
sni, wrapped, isTLS, err := PeekClientHello(conn)
|
|
if err != nil {
|
|
r.logger.Debugf("SNI peek failed for conn from %s: %v", conn.RemoteAddr(), err)
|
|
if wrapped != nil {
|
|
r.handleUnmatched(ctx, wrapped, isTLS)
|
|
} else {
|
|
_ = conn.Close()
|
|
}
|
|
return
|
|
}
|
|
|
|
if err := wrapped.SetReadDeadline(time.Time{}); err != nil {
|
|
r.logger.Debugf("clear SNI peek deadline: %v", err)
|
|
_ = wrapped.Close()
|
|
return
|
|
}
|
|
|
|
host := SNIHost(strings.ToLower(sni))
|
|
route, ok := r.lookupRoute(host)
|
|
r.logger.WithFields(log.Fields{
|
|
"remote": wrapped.RemoteAddr().String(),
|
|
"sni": string(host),
|
|
"match": ok,
|
|
"tls": isTLS,
|
|
}).Debug("SNI route lookup")
|
|
if !ok {
|
|
r.handleUnmatched(ctx, wrapped, isTLS)
|
|
return
|
|
}
|
|
|
|
if route.Type == RouteHTTP {
|
|
r.logger.Debugf("SNI %q routed to HTTP handler (service_id=%s)", host, route.ServiceID)
|
|
r.sendToHTTP(wrapped, isTLS)
|
|
return
|
|
}
|
|
|
|
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
|
|
if !errors.Is(err, errAccessRestricted) {
|
|
r.logger.WithFields(log.Fields{
|
|
"sni": host,
|
|
"service_id": route.ServiceID,
|
|
"target": route.Target,
|
|
}).Warnf("TCP relay: %v", err)
|
|
}
|
|
_ = wrapped.Close()
|
|
}
|
|
}
|
|
|
|
// isFallbackOnly returns true when the router has no SNI routes and no HTTP
|
|
// channel, meaning all connections should go directly to the fallback relay.
|
|
func (r *Router) isFallbackOnly() bool {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return len(r.routes) == 0 && r.httpCh == nil
|
|
}
|
|
|
|
// handleUnmatched routes a connection that didn't match any SNI route.
|
|
// This includes ECH/ESNI connections where the cleartext SNI is empty,
|
|
// and plain (non-TLS) HTTP connections when isTLS is false.
|
|
// It tries the fallback relay first, then the HTTP channel, and closes
|
|
// the connection if neither is available.
|
|
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn, isTLS bool) {
|
|
r.mu.RLock()
|
|
fb := r.fallback
|
|
r.mu.RUnlock()
|
|
|
|
if fb != nil {
|
|
r.logger.Debugf("unmatched conn from %s relayed to TCP fallback (service_id=%s, target=%s)", conn.RemoteAddr(), fb.ServiceID, fb.Target)
|
|
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
|
|
if !errors.Is(err, errAccessRestricted) {
|
|
r.logger.WithFields(log.Fields{
|
|
"service_id": fb.ServiceID,
|
|
"target": fb.Target,
|
|
}).Warnf("TCP relay (fallback): %v", err)
|
|
}
|
|
_ = conn.Close()
|
|
}
|
|
return
|
|
}
|
|
r.logger.Debugf("unmatched conn from %s sent to HTTP channel (no TCP fallback configured)", conn.RemoteAddr())
|
|
r.sendToHTTP(conn, isTLS)
|
|
}
|
|
|
|
// lookupRoute returns the highest-priority route for the given SNI host.
|
|
// HTTP routes take precedence over TCP routes.
|
|
func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
routes, ok := r.routes[host]
|
|
if !ok || len(routes) == 0 {
|
|
return Route{}, false
|
|
}
|
|
best := routes[0]
|
|
for _, route := range routes[1:] {
|
|
if route.Type < best.Type {
|
|
best = route
|
|
}
|
|
}
|
|
return best, true
|
|
}
|
|
|
|
// sendToHTTP feeds the connection to the HTTP handler via the channel.
|
|
// When isTLS is false and a plain channel is configured the connection
|
|
// is forwarded to the plain channel; otherwise it lands on the TLS
|
|
// channel. If no usable channel exists, the router is draining, or the
|
|
// channel is full, the connection is closed.
|
|
func (r *Router) sendToHTTP(conn net.Conn, isTLS bool) {
|
|
ch := r.httpCh
|
|
chanName := "HTTP"
|
|
if !isTLS && r.httpPlainCh != nil {
|
|
ch = r.httpPlainCh
|
|
chanName = "HTTP-plain"
|
|
}
|
|
|
|
if ch == nil {
|
|
r.logger.Debugf("%s channel nil; dropping conn from %s", chanName, conn.RemoteAddr())
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
|
|
r.mu.RLock()
|
|
draining := r.draining
|
|
r.mu.RUnlock()
|
|
|
|
if draining {
|
|
r.logger.Debugf("router draining; dropping conn from %s", conn.RemoteAddr())
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
|
|
select {
|
|
case ch <- conn:
|
|
default:
|
|
r.logger.Warnf("%s channel full, dropping connection from %s", chanName, conn.RemoteAddr())
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
// Drain prevents new relay connections from starting and waits for all
|
|
// in-flight connection handlers and active relays to finish, up to the
|
|
// given timeout. Returns true if all completed, false on timeout.
|
|
func (r *Router) Drain(timeout time.Duration) bool {
|
|
r.mu.Lock()
|
|
r.draining = true
|
|
if r.drainDone == nil {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
r.activeConns.Wait()
|
|
r.activeRelays.Wait()
|
|
close(done)
|
|
}()
|
|
r.drainDone = done
|
|
}
|
|
done := r.drainDone
|
|
r.mu.Unlock()
|
|
|
|
select {
|
|
case <-done:
|
|
return true
|
|
case <-time.After(timeout):
|
|
return false
|
|
}
|
|
}
|
|
|
|
// cancelServiceLocked cancels and removes the context for the given service,
|
|
// closing all its active relay connections. Must be called with mu held.
|
|
func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
|
|
if cancel, ok := r.svcCancels[svcID]; ok {
|
|
cancel()
|
|
delete(r.svcCtxs, svcID)
|
|
delete(r.svcCancels, svcID)
|
|
}
|
|
}
|
|
|
|
// SetGeo sets the geolocation lookup used for country-based restrictions.
|
|
func (r *Router) SetGeo(geo restrict.GeoResolver) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.geo = geo
|
|
}
|
|
|
|
// checkRestrictions evaluates the route's access filter against the
|
|
// connection's remote address. Returns Allow if the connection is
|
|
// permitted, or a deny verdict indicating the reason.
|
|
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
|
|
if route.Filter == nil {
|
|
return restrict.Allow
|
|
}
|
|
|
|
addr, err := addrFromConn(conn)
|
|
if err != nil {
|
|
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
|
|
return restrict.DenyCIDR
|
|
}
|
|
|
|
r.mu.RLock()
|
|
geo := r.geo
|
|
r.mu.RUnlock()
|
|
|
|
return route.Filter.Check(addr, geo)
|
|
}
|
|
|
|
// relayTCP sets up and runs a bidirectional TCP relay.
|
|
// The caller owns conn and must close it if this method returns an error.
|
|
// On success (nil error), both conn and backend are closed by the relay.
|
|
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
|
|
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
|
|
if route.Filter != nil && route.Filter.IsObserveOnly(verdict) {
|
|
r.logger.Debugf("CrowdSec observe: would block %s for %s (%s)", conn.RemoteAddr(), sni, verdict)
|
|
r.logL4Deny(route, conn, verdict, true)
|
|
} else {
|
|
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
|
|
r.logL4Deny(route, conn, verdict, false)
|
|
return errAccessRestricted
|
|
}
|
|
}
|
|
|
|
svcCtx, err := r.acquireRelay(ctx, route)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
<-r.relaySem
|
|
r.activeRelays.Done()
|
|
}()
|
|
|
|
backend, err := r.dialBackend(svcCtx, route)
|
|
if err != nil {
|
|
obs := r.getObserver()
|
|
if obs != nil {
|
|
obs.TCPRelayDialError(route.AccountID)
|
|
}
|
|
return err
|
|
}
|
|
|
|
if route.ProxyProtocol {
|
|
if err := writeProxyProtoV2(conn, backend); err != nil {
|
|
_ = backend.Close()
|
|
return fmt.Errorf("write PROXY protocol header: %w", err)
|
|
}
|
|
}
|
|
|
|
obs := r.getObserver()
|
|
if obs != nil {
|
|
obs.TCPRelayStarted(route.AccountID)
|
|
}
|
|
|
|
entry := r.logger.WithFields(log.Fields{
|
|
"sni": sni,
|
|
"service_id": route.ServiceID,
|
|
"target": route.Target,
|
|
})
|
|
entry.Debug("TCP relay started")
|
|
|
|
idleTimeout := route.SessionIdleTimeout
|
|
if idleTimeout <= 0 {
|
|
idleTimeout = netrelay.DefaultIdleTimeout
|
|
}
|
|
|
|
start := time.Now()
|
|
s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{
|
|
IdleTimeout: idleTimeout,
|
|
Logger: entry,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if obs != nil {
|
|
obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s)
|
|
}
|
|
entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s)
|
|
|
|
r.logL4Entry(route, conn, elapsed, s2d, d2s)
|
|
return nil
|
|
}
|
|
|
|
// acquireRelay checks draining state, increments activeRelays, and acquires
|
|
// a semaphore slot. Returns the per-service context on success.
|
|
// The caller must release the semaphore and call activeRelays.Done() when done.
|
|
func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) {
|
|
r.mu.Lock()
|
|
if r.draining {
|
|
r.mu.Unlock()
|
|
return nil, errors.New("router is draining")
|
|
}
|
|
r.activeRelays.Add(1)
|
|
svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID)
|
|
r.mu.Unlock()
|
|
|
|
select {
|
|
case r.relaySem <- struct{}{}:
|
|
return svcCtx, nil
|
|
default:
|
|
r.activeRelays.Done()
|
|
obs := r.getObserver()
|
|
if obs != nil {
|
|
obs.TCPRelayRejected(route.AccountID)
|
|
}
|
|
return nil, errors.New("TCP relay connection limit reached")
|
|
}
|
|
}
|
|
|
|
// dialBackend resolves the dialer for the route's account and dials the backend.
|
|
func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) {
|
|
dialFn, err := r.dialResolve(route.AccountID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve dialer: %w", err)
|
|
}
|
|
|
|
dialTimeout := route.DialTimeout
|
|
if dialTimeout <= 0 {
|
|
dialTimeout = defaultDialTimeout
|
|
}
|
|
dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout)
|
|
backend, err := dialFn(dialCtx, "tcp", route.Target)
|
|
dialCancel()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial backend %s: %w", route.Target, err)
|
|
}
|
|
return backend, nil
|
|
}
|
|
|
|
// logL4Entry sends a TCP relay access log entry if an access logger is configured.
|
|
func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) {
|
|
r.mu.RLock()
|
|
al := r.accessLog
|
|
r.mu.RUnlock()
|
|
|
|
if al == nil {
|
|
return
|
|
}
|
|
|
|
sourceIP, _ := addrFromConn(conn)
|
|
|
|
al.LogL4(accesslog.L4Entry{
|
|
AccountID: route.AccountID,
|
|
ServiceID: route.ServiceID,
|
|
Protocol: route.Protocol,
|
|
Host: route.Domain,
|
|
SourceIP: sourceIP,
|
|
DurationMs: duration.Milliseconds(),
|
|
BytesUpload: bytesUp,
|
|
BytesDownload: bytesDown,
|
|
})
|
|
}
|
|
|
|
// logL4Deny sends an access log entry for a denied connection.
|
|
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict, observeOnly bool) {
|
|
r.mu.RLock()
|
|
al := r.accessLog
|
|
r.mu.RUnlock()
|
|
|
|
if al == nil {
|
|
return
|
|
}
|
|
|
|
sourceIP, _ := addrFromConn(conn)
|
|
|
|
entry := accesslog.L4Entry{
|
|
AccountID: route.AccountID,
|
|
ServiceID: route.ServiceID,
|
|
Protocol: route.Protocol,
|
|
Host: route.Domain,
|
|
SourceIP: sourceIP,
|
|
DenyReason: verdict.String(),
|
|
}
|
|
if verdict.IsCrowdSec() {
|
|
entry.Metadata = map[string]string{"crowdsec_verdict": verdict.String()}
|
|
if observeOnly {
|
|
entry.Metadata["crowdsec_mode"] = "observe"
|
|
entry.DenyReason = ""
|
|
}
|
|
}
|
|
al.LogL4(entry)
|
|
}
|
|
|
|
// getOrCreateServiceCtxLocked returns the context for a service, creating one
|
|
// if it doesn't exist yet. The context is a child of the server context.
|
|
// Must be called with mu held.
|
|
func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context {
|
|
if ctx, ok := r.svcCtxs[svcID]; ok {
|
|
return ctx
|
|
}
|
|
ctx, cancel := context.WithCancel(parent)
|
|
r.svcCtxs[svcID] = ctx
|
|
r.svcCancels[svcID] = cancel
|
|
return ctx
|
|
}
|
|
|
|
// addrFromConn extracts the remote IP of conn as a netip.Addr, with
// IPv4-mapped IPv6 addresses unmapped to plain IPv4.
//
// It returns an error when the connection reports no remote address or the
// address is not a valid "ip:port".
func addrFromConn(conn net.Conn) (netip.Addr, error) {
	remote := conn.RemoteAddr()
	if remote == nil {
		return netip.Addr{}, errors.New("no remote address")
	}

	// Fast path: *net.TCPAddr and *net.UDPAddr expose the address
	// directly, which avoids formatting it to a string and parsing it back
	// for every connection.
	if typed, ok := remote.(interface{ AddrPort() netip.AddrPort }); ok {
		if addr := typed.AddrPort().Addr(); addr.IsValid() {
			return addr.Unmap(), nil
		}
	}

	// Generic path for any other net.Addr, and for typed addresses that
	// carry no usable IP (the parse then reports the error).
	ap, err := netip.ParseAddrPort(remote.String())
	if err != nil {
		return netip.Addr{}, err
	}
	return ap.Addr().Unmap(), nil
}
|