[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -7,12 +7,14 @@ import (
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
@@ -20,6 +22,10 @@ import (
// timeout is configured.
const defaultDialTimeout = 30 * time.Second
// errAccessRestricted is returned by relayTCP for access restriction
// denials so callers can skip warn-level logging (already logged at debug).
var errAccessRestricted = errors.New("rejected by access restrictions")
// SNIHost is a typed key for SNI hostname lookups.
type SNIHost string
@@ -64,6 +70,11 @@ type Route struct {
// DialTimeout overrides the default dial timeout for this route.
// Zero uses defaultDialTimeout.
DialTimeout time.Duration
// SessionIdleTimeout overrides the default idle timeout for relay connections.
// Zero uses DefaultIdleTimeout.
SessionIdleTimeout time.Duration
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
Filter *restrict.Filter
}
// l4Logger sends layer-4 access log entries to the management server.
@@ -99,6 +110,7 @@ type Router struct {
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// svcCtxs tracks a context per service ID. All relay goroutines for a
// service derive from its context; canceling it kills them immediately.
svcCtxs map[types.ServiceID]context.Context
@@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener {
// stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
func (r *Router) AddRoute(host SNIHost, route Route) {
host = SNIHost(strings.ToLower(string(host)))
if host == "" {
return
}
@@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) {
// Active relay connections for the service are closed immediately.
// If other routes remain for the host, they are preserved.
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
host = SNIHost(strings.ToLower(string(host)))
r.mu.Lock()
defer r.mu.Unlock()
@@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
return
}
host := SNIHost(sni)
host := SNIHost(strings.ToLower(sni))
route, ok := r.lookupRoute(host)
if !ok {
r.handleUnmatched(ctx, wrapped)
@@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
}
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
r.logger.WithFields(log.Fields{
"sni": host,
"service_id": route.ServiceID,
"target": route.Target,
}).Warnf("TCP relay: %v", err)
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
"sni": host,
"service_id": route.ServiceID,
"target": route.Target,
}).Warnf("TCP relay: %v", err)
}
_ = wrapped.Close()
}
}
@@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
if fb != nil {
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
r.logger.WithFields(log.Fields{
"service_id": fb.ServiceID,
"target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
"service_id": fb.ServiceID,
"target": fb.Target,
}).Warnf("TCP relay (fallback): %v", err)
}
_ = conn.Close()
}
return
@@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
}
}
// SetGeo sets the geolocation lookup used for country-based restrictions.
func (r *Router) SetGeo(geo restrict.GeoResolver) {
r.mu.Lock()
defer r.mu.Unlock()
r.geo = geo
}
// checkRestrictions evaluates the route's access filter against the
// connection's remote address. Returns Allow if the connection is
// permitted, or a deny verdict indicating the reason.
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
if route.Filter == nil {
return restrict.Allow
}
addr, err := addrFromConn(conn)
if err != nil {
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
return restrict.DenyCIDR
}
r.mu.RLock()
geo := r.geo
r.mu.RUnlock()
return route.Filter.Check(addr, geo)
}
// relayTCP sets up and runs a bidirectional TCP relay.
// The caller owns conn and must close it if this method returns an error.
// On success (nil error), both conn and backend are closed by the relay.
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
r.logL4Deny(route, conn, verdict)
return errAccessRestricted
}
svcCtx, err := r.acquireRelay(ctx, route)
if err != nil {
return err
@@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
})
entry.Debug("TCP relay started")
idleTimeout := route.SessionIdleTimeout
if idleTimeout <= 0 {
idleTimeout = DefaultIdleTimeout
}
start := time.Now()
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
elapsed := time.Since(start)
if obs != nil {
@@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
return
}
var sourceIP netip.Addr
if remote := conn.RemoteAddr(); remote != nil {
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
sourceIP = ap.Addr().Unmap()
}
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
@@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
})
}
// logL4Deny sends an access log entry for a denied connection.
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
r.mu.RLock()
al := r.accessLog
r.mu.RUnlock()
if al == nil {
return
}
sourceIP, _ := addrFromConn(conn)
al.LogL4(accesslog.L4Entry{
AccountID: route.AccountID,
ServiceID: route.ServiceID,
Protocol: route.Protocol,
Host: route.Domain,
SourceIP: sourceIP,
DenyReason: verdict.String(),
})
}
// getOrCreateServiceCtxLocked returns the context for a service, creating one
// if it doesn't exist yet. The context is a child of the server context.
// Must be called with mu held.
@@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types
r.svcCancels[svcID] = cancel
return ctx
}
// addrFromConn extracts a netip.Addr from a connection's remote address.
func addrFromConn(conn net.Conn) (netip.Addr, error) {
remote := conn.RemoteAddr()
if remote == nil {
return netip.Addr{}, errors.New("no remote address")
}
ap, err := netip.ParseAddrPort(remote.String())
if err != nil {
return netip.Addr{}, err
}
return ap.Addr().Unmap(), nil
}