mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
188
proxy/server.go
188
proxy/server.go
@@ -43,12 +43,14 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
"github.com/netbirdio/netbird/proxy/internal/geolocation"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
@@ -59,7 +61,6 @@ import (
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
|
||||
// portRouter bundles a per-port Router with its listener and cancel func.
|
||||
type portRouter struct {
|
||||
router *nbtcp.Router
|
||||
@@ -95,6 +96,9 @@ type Server struct {
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
// does not handle them.
|
||||
hijackTracker conntrack.HijackTracker
|
||||
// geo resolves IP addresses to country/city for access restrictions and access logs.
|
||||
geo restrict.GeoResolver
|
||||
geoRaw *geolocation.Lookup
|
||||
|
||||
// routerReady is closed once mainRouter is fully initialized.
|
||||
// The mapping worker waits on this before processing updates.
|
||||
@@ -159,10 +163,38 @@ type Server struct {
|
||||
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
|
||||
// ports for TCP/UDP/TLS services.
|
||||
SupportsCustomPorts bool
|
||||
// DefaultDialTimeout is the default timeout for establishing backend
|
||||
// connections when no per-service timeout is configured. Zero means
|
||||
// each transport uses its own hardcoded default (typically 30s).
|
||||
DefaultDialTimeout time.Duration
|
||||
// MaxDialTimeout caps the per-service backend dial timeout.
|
||||
// When the API sends a timeout, it is clamped to this value.
|
||||
// When the API sends no timeout, this value is used as the default.
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
MaxDialTimeout time.Duration
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files for
|
||||
// country-based access restrictions. Empty disables geo lookups.
|
||||
GeoDataDir string
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
func (s *Server) clampIdleTimeout(d time.Duration) time.Duration {
|
||||
if s.MaxSessionIdleTimeout > 0 && d > s.MaxSessionIdleTimeout {
|
||||
return s.MaxSessionIdleTimeout
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// clampDialTimeout returns d capped to MaxDialTimeout when configured.
|
||||
// If d is zero, MaxDialTimeout is used as the default.
|
||||
func (s *Server) clampDialTimeout(d time.Duration) time.Duration {
|
||||
if s.MaxDialTimeout <= 0 {
|
||||
return d
|
||||
}
|
||||
if d <= 0 || d > s.MaxDialTimeout {
|
||||
return s.MaxDialTimeout
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// NotifyStatus sends a status update to management about tunnel connectivity.
|
||||
@@ -226,7 +258,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
runCtx, runCancel := context.WithCancel(ctx)
|
||||
defer runCancel()
|
||||
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
|
||||
|
||||
// Initialize the netbird client, this is required to build peer connections
|
||||
// to proxy over.
|
||||
@@ -236,6 +267,12 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
PreSharedKey: s.PreSharedKey,
|
||||
}, s.Logger, s, s.mgmtClient)
|
||||
|
||||
// Create health checker before the mapping worker so it can track
|
||||
// management connectivity from the first stream connection.
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
go s.newManagementMappingWorker(runCtx, s.mgmtClient)
|
||||
|
||||
tlsConfig, err := s.configureTLS(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -244,14 +281,33 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
|
||||
|
||||
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize geolocation: %w", err)
|
||||
}
|
||||
s.geoRaw = geoLookup
|
||||
if geoLookup != nil {
|
||||
s.geo = geoLookup
|
||||
}
|
||||
|
||||
var startupOK bool
|
||||
defer func() {
|
||||
if startupOK {
|
||||
return
|
||||
}
|
||||
if s.geoRaw != nil {
|
||||
if err := s.geoRaw.Close(); err != nil {
|
||||
s.Logger.Debugf("close geolocation on startup failure: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Configure the authentication middleware with session validator for OIDC group checks.
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient)
|
||||
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
|
||||
|
||||
// Configure Access logs to management server.
|
||||
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
|
||||
|
||||
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
|
||||
|
||||
s.startDebugEndpoint()
|
||||
|
||||
if err := s.startHealthServer(); err != nil {
|
||||
@@ -294,6 +350,8 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
|
||||
startupOK = true
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
|
||||
@@ -691,6 +749,16 @@ func (s *Server) shutdownServices() {
|
||||
s.portRouterWg.Wait()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if s.accessLog != nil {
|
||||
s.accessLog.Close()
|
||||
}
|
||||
|
||||
if s.geoRaw != nil {
|
||||
if err := s.geoRaw.Close(); err != nil {
|
||||
s.Logger.Debugf("close geolocation: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveDialFunc returns a DialContextFunc that dials through the
|
||||
@@ -1073,15 +1141,20 @@ func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
return fmt.Errorf("router for TCP port %d: %w", port, err)
|
||||
}
|
||||
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
router.SetGeo(s.geo)
|
||||
router.SetFallback(nbtcp.Route{
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTCP,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTCP,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
Filter: parseRestrictions(mapping),
|
||||
})
|
||||
|
||||
s.portMu.Lock()
|
||||
@@ -1108,6 +1181,8 @@ func (s *Server) setupUDPMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
return fmt.Errorf("empty target address for UDP service %s", svcID)
|
||||
}
|
||||
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
if err := s.addUDPRelay(ctx, mapping, targetAddr, port); err != nil {
|
||||
return fmt.Errorf("UDP relay for service %s: %w", svcID, err)
|
||||
}
|
||||
@@ -1141,15 +1216,20 @@ func (s *Server) setupTLSMapping(ctx context.Context, mapping *proto.ProxyMappin
|
||||
return fmt.Errorf("router for TLS port %d: %w", tlsPort, err)
|
||||
}
|
||||
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
router.SetGeo(s.geo)
|
||||
router.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTLS,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
Type: nbtcp.RouteTCP,
|
||||
AccountID: accountID,
|
||||
ServiceID: svcID,
|
||||
Domain: mapping.GetDomain(),
|
||||
Protocol: accesslog.ProtocolTLS,
|
||||
Target: targetAddr,
|
||||
ProxyProtocol: s.l4ProxyProtocol(mapping),
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionIdleTimeout: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
Filter: parseRestrictions(mapping),
|
||||
})
|
||||
|
||||
if tlsPort != s.mainPort {
|
||||
@@ -1181,6 +1261,32 @@ func (s *Server) serviceKeyForMapping(mapping *proto.ProxyMapping) roundtrip.Ser
|
||||
}
|
||||
}
|
||||
|
||||
// parseRestrictions converts a proto mapping's access restrictions into
|
||||
// a restrict.Filter. Returns nil if the mapping has no restrictions.
|
||||
func parseRestrictions(mapping *proto.ProxyMapping) *restrict.Filter {
|
||||
r := mapping.GetAccessRestrictions()
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return restrict.ParseFilter(r.GetAllowedCidrs(), r.GetBlockedCidrs(), r.GetAllowedCountries(), r.GetBlockedCountries())
|
||||
}
|
||||
|
||||
// warnIfGeoUnavailable logs a warning if the mapping has country restrictions
|
||||
// but the proxy has no geolocation database loaded. All requests to this
|
||||
// service will be denied at runtime (fail-close).
|
||||
func (s *Server) warnIfGeoUnavailable(domain string, r *proto.AccessRestrictions) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
if len(r.GetAllowedCountries()) == 0 && len(r.GetBlockedCountries()) == 0 {
|
||||
return
|
||||
}
|
||||
if s.geo != nil && s.geo.Available() {
|
||||
return
|
||||
}
|
||||
s.Logger.Warnf("service %s has country restrictions but no geolocation database is loaded: all requests will be denied", domain)
|
||||
}
|
||||
|
||||
// l4TargetAddress extracts and validates the target address from a mapping's
|
||||
// first path entry. Returns empty string if no paths exist or the address is
|
||||
// not a valid host:port.
|
||||
@@ -1210,15 +1316,15 @@ func (s *Server) l4ProxyProtocol(mapping *proto.ProxyMapping) bool {
|
||||
}
|
||||
|
||||
// l4DialTimeout returns the dial timeout from the first target's options,
|
||||
// falling back to the server's DefaultDialTimeout.
|
||||
// clamped to MaxDialTimeout.
|
||||
func (s *Server) l4DialTimeout(mapping *proto.ProxyMapping) time.Duration {
|
||||
paths := mapping.GetPath()
|
||||
if len(paths) > 0 {
|
||||
if d := paths[0].GetOptions().GetRequestTimeout(); d != nil {
|
||||
return d.AsDuration()
|
||||
return s.clampDialTimeout(d.AsDuration())
|
||||
}
|
||||
}
|
||||
return s.DefaultDialTimeout
|
||||
return s.clampDialTimeout(0)
|
||||
}
|
||||
|
||||
// l4SessionIdleTimeout returns the configured session idle timeout from the
|
||||
@@ -1254,7 +1360,9 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
|
||||
|
||||
dialFn, err := s.resolveDialFunc(accountID)
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
if err := listener.Close(); err != nil {
|
||||
s.Logger.Debugf("close UDP listener on %s: %v", listenAddr, err)
|
||||
}
|
||||
return fmt.Errorf("resolve dialer for UDP: %w", err)
|
||||
}
|
||||
|
||||
@@ -1273,8 +1381,10 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
|
||||
ServiceID: svcID,
|
||||
DialFunc: dialFn,
|
||||
DialTimeout: s.l4DialTimeout(mapping),
|
||||
SessionTTL: l4SessionIdleTimeout(mapping),
|
||||
SessionTTL: s.clampIdleTimeout(l4SessionIdleTimeout(mapping)),
|
||||
AccessLog: s.accessLog,
|
||||
Filter: parseRestrictions(mapping),
|
||||
Geo: s.geo,
|
||||
})
|
||||
relay.SetObserver(s.meter)
|
||||
|
||||
@@ -1306,9 +1416,15 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
if mapping.GetAuth().GetOidc() {
|
||||
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, svcID, accountID, s.ForwardedProto))
|
||||
}
|
||||
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
|
||||
schemes = append(schemes, auth.NewHeader(s.mgmtClient, svcID, accountID, ha.GetHeader()))
|
||||
}
|
||||
|
||||
ipRestrictions := parseRestrictions(mapping)
|
||||
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID); err != nil {
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
|
||||
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
|
||||
}
|
||||
m := s.protoToMapping(ctx, mapping)
|
||||
@@ -1449,12 +1565,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
|
||||
pt.RequestTimeout = d.AsDuration()
|
||||
}
|
||||
}
|
||||
if pt.RequestTimeout == 0 && s.DefaultDialTimeout > 0 {
|
||||
pt.RequestTimeout = s.DefaultDialTimeout
|
||||
}
|
||||
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
|
||||
paths[pathMapping.GetPath()] = pt
|
||||
}
|
||||
return proxy.Mapping{
|
||||
m := proxy.Mapping{
|
||||
ID: types.ServiceID(mapping.GetId()),
|
||||
AccountID: types.AccountID(mapping.GetAccountId()),
|
||||
Host: mapping.GetDomain(),
|
||||
@@ -1462,6 +1576,10 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
|
||||
PassHostHeader: mapping.GetPassHostHeader(),
|
||||
RewriteRedirects: mapping.GetRewriteRedirects(),
|
||||
}
|
||||
for _, ha := range mapping.GetAuth().GetHeaderAuths() {
|
||||
m.StripAuthHeaders = append(m.StripAuthHeaders, ha.GetHeader())
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
|
||||
|
||||
Reference in New Issue
Block a user