[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -43,12 +43,14 @@ import (
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/geolocation"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s"
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/netutil"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -59,7 +61,6 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots"
)
// portRouter bundles a per-port Router with its listener and cancel func.
type portRouter struct {
router *nbtcp.Router
@@ -95,6 +96,9 @@ type Server struct {
// so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them.
hijackTracker conntrack.HijackTracker
// geo resolves IP addresses to country/city for access restrictions and access logs.
geo restrict.GeoResolver
geoRaw *geolocation.Lookup
// routerReady is closed once mainRouter is fully initialized.
// The mapping worker waits on this before processing updates.
@@ -159,10 +163,38 @@ type Server struct {
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
// ports for TCP/UDP/TLS services.
SupportsCustomPorts bool
// DefaultDialTimeout is the default timeout for establishing backend
// connections when no per-service timeout is configured. Zero means
// each transport uses its own hardcoded default (typically 30s).
DefaultDialTimeout time.Duration
// MaxDialTimeout caps the per-service backend dial timeout.
// When the API sends a timeout, it is clamped to this value.
// When the API sends no timeout, this value is used as the default.
// Zero means no cap (the proxy honors whatever management sends).
MaxDialTimeout time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files for
// country-based access restrictions. Empty disables geo lookups.
GeoDataDir string
// MaxSessionIdleTimeout caps the per-service session idle timeout.
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
return s.MaxSessionIdleTimeout
}
return d
}
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
// If d is zero, MaxDialTimeout is used as the default.
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
if s.MaxDialTimeout <= 0 {
return d
}
if d <= 0 || d > s.MaxDialTimeout {
return s.MaxDialTimeout
}
return d
}
// NotifyStatus sends a status update to management about tunnel connectivity.
@@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
@@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
tlsConfig, err := s.configureTLS(ctx)
if err != nil {
return err
@@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
}
var startupOK bool
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.startDebugEndpoint()
if err := s.startHealthServer(); err != nil {
@@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
}
startupOK = true
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
@@ -691,6 +749,16 @@ func (s *Server) shutdownServices() {
s.portRouterWg.Wait()
wg.Wait()
if s.accessLog != nil {
s.accessLog.Close()
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation: %v", err)
}
}
}
// resolveDialFunc returns a DialContextFunc that dials through the
@@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TCP port %d: %w", port, err)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.SetFallback(nbtcp.Route{
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTCP,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
})
s.portMu.Lock()
@@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("empty target address for UDP service %s", svcID)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
}
@@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
}
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
router.SetGeo(s.geo)
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
Type: nbtcp.RouteTCP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
Protocol: accesslog.ProtocolTLS,
Target: targetAddr,
ProxyProtocol: s.l4ProxyProtocol(mapping),
DialTimeout: s.l4DialTimeout(mapping),
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
Filter: parseRestrictions(mapping),
})
if tlsPort != s.mainPort {
@@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
}
}
// parseRestrictions converts a proto mapping's access restrictions into
// a restrict.Filter. Returns nil if the mapping has no restrictions.
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
r := mapping.GetAccessRestrictions()
if r == nil {
return nil
}
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
}
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
// but the proxy has no geolocation database loaded. All requests to this
// service will be denied at runtime (fail-close).
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
if r == nil {
return
}
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
return
}
if s.geo != nil && s.geo.Available() {
return
}
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
}
// l4TargetAddress extracts and validates the target address from a mapping's
// first path entry. Returns empty string if no paths exist or the address is
// not a valid host:port.
@@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
}
// l4DialTimeout returns the dial timeout from the first target's options,
// falling back to the server's DefaultDialTimeout.
// clamped to MaxDialTimeout.
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
paths := mapping.GetPath()
if len(paths) > 0 {
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
return d.AsDuration()
return s.clampDialTimeout(d.AsDuration())
}
}
return s.DefaultDialTimeout
return s.clampDialTimeout(0)
}
// l4SessionIdleTimeout returns the configured session idle timeout from the
@@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
dialFn, err := s.resolveDialFunc(accountID)
if err != nil {
_ = listener.Close()
if err := listener.Close(); err != nil {
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
}
return fmt.Errorf("resolve dialer for UDP: %w", err)
}
@@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
ServiceID: svcID,
DialFunc: dialFn,
DialTimeout: s.l4DialTimeout(mapping),
SessionTTL: l4SessionIdleTimeout(mapping),
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
AccessLog: s.accessLog,
Filter: parseRestrictions(mapping),
Geo: s.geo,
})
relay.SetObserver(s.meter)
@@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
if mapping.GetAuth().GetOidc() {
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
}
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
}
ipRestrictions := parseRestrictions(mapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil {
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
}
m := s.protoToMapping(ctx, mapping)
@@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
pt.RequestTimeout = d.AsDuration()
}
}
if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 {
pt.RequestTimeout = s.DefaultDialTimeout
}
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
paths[pathMapping.GetPath()] = pt
}
return proxy.Mapping{
m := proxy.Mapping{
ID: types.ServiceID(mapping.GetId()),
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
@@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
PassHostHeader: mapping.GetPassHostHeader(),
RewriteRedirects: mapping.GetRewriteRedirects(),
}
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
}
return m
}
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {