mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
@@ -7,12 +7,14 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
@@ -20,6 +22,10 @@ import (
|
||||
// timeout is configured.
|
||||
const defaultDialTimeout = 30 * time.Second
|
||||
|
||||
// errAccessRestricted is returned by relayTCP for access restriction
|
||||
// denials so callers can skip warn-level logging (already logged at debug).
|
||||
var errAccessRestricted = errors.New("rejected by access restrictions")
|
||||
|
||||
// SNIHost is a typed key for SNI hostname lookups.
|
||||
type SNIHost string
|
||||
|
||||
@@ -64,6 +70,11 @@ type Route struct {
|
||||
// DialTimeout overrides the default dial timeout for this route.
|
||||
// Zero uses defaultDialTimeout.
|
||||
DialTimeout time.Duration
|
||||
// SessionIdleTimeout overrides the default idle timeout for relay connections.
|
||||
// Zero uses DefaultIdleTimeout.
|
||||
SessionIdleTimeout time.Duration
|
||||
// Filter holds connection-level IP/geo restrictions. Nil means no restrictions.
|
||||
Filter *restrict.Filter
|
||||
}
|
||||
|
||||
// l4Logger sends layer-4 access log entries to the management server.
|
||||
@@ -99,6 +110,7 @@ type Router struct {
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
geo restrict.GeoResolver
|
||||
// svcCtxs tracks a context per service ID. All relay goroutines for a
|
||||
// service derive from its context; canceling it kills them immediately.
|
||||
svcCtxs map[types.ServiceID]context.Context
|
||||
@@ -144,6 +156,7 @@ func (r *Router) HTTPListener() net.Listener {
|
||||
// stored and resolved by priority at lookup time (HTTP > TCP).
|
||||
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
||||
func (r *Router) AddRoute(host SNIHost, route Route) {
|
||||
host = SNIHost(strings.ToLower(string(host)))
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
@@ -166,6 +179,8 @@ func (r *Router) AddRoute(host SNIHost, route Route) {
|
||||
// Active relay connections for the service are closed immediately.
|
||||
// If other routes remain for the host, they are preserved.
|
||||
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
|
||||
host = SNIHost(strings.ToLower(string(host)))
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
@@ -295,7 +310,7 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
host := SNIHost(sni)
|
||||
host := SNIHost(strings.ToLower(sni))
|
||||
route, ok := r.lookupRoute(host)
|
||||
if !ok {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
@@ -308,11 +323,13 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"sni": host,
|
||||
"service_id": route.ServiceID,
|
||||
"target": route.Target,
|
||||
}).Warnf("TCP relay: %v", err)
|
||||
if !errors.Is(err, errAccessRestricted) {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"sni": host,
|
||||
"service_id": route.ServiceID,
|
||||
"target": route.Target,
|
||||
}).Warnf("TCP relay: %v", err)
|
||||
}
|
||||
_ = wrapped.Close()
|
||||
}
|
||||
}
|
||||
@@ -336,10 +353,12 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
|
||||
if fb != nil {
|
||||
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"service_id": fb.ServiceID,
|
||||
"target": fb.Target,
|
||||
}).Warnf("TCP relay (fallback): %v", err)
|
||||
if !errors.Is(err, errAccessRestricted) {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"service_id": fb.ServiceID,
|
||||
"target": fb.Target,
|
||||
}).Warnf("TCP relay (fallback): %v", err)
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
@@ -427,10 +446,44 @@ func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetGeo sets the geolocation lookup used for country-based restrictions.
|
||||
func (r *Router) SetGeo(geo restrict.GeoResolver) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.geo = geo
|
||||
}
|
||||
|
||||
// checkRestrictions evaluates the route's access filter against the
|
||||
// connection's remote address. Returns Allow if the connection is
|
||||
// permitted, or a deny verdict indicating the reason.
|
||||
func (r *Router) checkRestrictions(conn net.Conn, route Route) restrict.Verdict {
|
||||
if route.Filter == nil {
|
||||
return restrict.Allow
|
||||
}
|
||||
|
||||
addr, err := addrFromConn(conn)
|
||||
if err != nil {
|
||||
r.logger.Debugf("cannot parse client address %s for restriction check, denying", conn.RemoteAddr())
|
||||
return restrict.DenyCIDR
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
geo := r.geo
|
||||
r.mu.RUnlock()
|
||||
|
||||
return route.Filter.Check(addr, geo)
|
||||
}
|
||||
|
||||
// relayTCP sets up and runs a bidirectional TCP relay.
|
||||
// The caller owns conn and must close it if this method returns an error.
|
||||
// On success (nil error), both conn and backend are closed by the relay.
|
||||
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
|
||||
if verdict := r.checkRestrictions(conn, route); verdict != restrict.Allow {
|
||||
r.logger.Debugf("connection from %s rejected by access restrictions: %s", conn.RemoteAddr(), verdict)
|
||||
r.logL4Deny(route, conn, verdict)
|
||||
return errAccessRestricted
|
||||
}
|
||||
|
||||
svcCtx, err := r.acquireRelay(ctx, route)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -468,8 +521,13 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
|
||||
})
|
||||
entry.Debug("TCP relay started")
|
||||
|
||||
idleTimeout := route.SessionIdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = DefaultIdleTimeout
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if obs != nil {
|
||||
@@ -537,12 +595,7 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
|
||||
return
|
||||
}
|
||||
|
||||
var sourceIP netip.Addr
|
||||
if remote := conn.RemoteAddr(); remote != nil {
|
||||
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
|
||||
sourceIP = ap.Addr().Unmap()
|
||||
}
|
||||
}
|
||||
sourceIP, _ := addrFromConn(conn)
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
@@ -556,6 +609,28 @@ func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration,
|
||||
})
|
||||
}
|
||||
|
||||
// logL4Deny sends an access log entry for a denied connection.
|
||||
func (r *Router) logL4Deny(route Route, conn net.Conn, verdict restrict.Verdict) {
|
||||
r.mu.RLock()
|
||||
al := r.accessLog
|
||||
r.mu.RUnlock()
|
||||
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
|
||||
sourceIP, _ := addrFromConn(conn)
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
ServiceID: route.ServiceID,
|
||||
Protocol: route.Protocol,
|
||||
Host: route.Domain,
|
||||
SourceIP: sourceIP,
|
||||
DenyReason: verdict.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// getOrCreateServiceCtxLocked returns the context for a service, creating one
|
||||
// if it doesn't exist yet. The context is a child of the server context.
|
||||
// Must be called with mu held.
|
||||
@@ -568,3 +643,16 @@ func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types
|
||||
r.svcCancels[svcID] = cancel
|
||||
return ctx
|
||||
}
|
||||
|
||||
// addrFromConn extracts a netip.Addr from a connection's remote address.
|
||||
func addrFromConn(conn net.Conn) (netip.Addr, error) {
|
||||
remote := conn.RemoteAddr()
|
||||
if remote == nil {
|
||||
return netip.Addr{}, errors.New("no remote address")
|
||||
}
|
||||
ap, err := netip.ParseAddrPort(remote.String())
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
return ap.Addr().Unmap(), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user