mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
@@ -7,14 +7,15 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
@@ -91,10 +92,37 @@ type BearerAuthConfig struct {
|
||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HeaderAuthConfig defines a static header-value auth check.
|
||||
// The proxy compares the incoming header value against the stored hash.
|
||||
type HeaderAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Header string `json:"header"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||
HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AccessRestrictions controls who can connect to the service based on IP or geography.
|
||||
type AccessRestrictions struct {
|
||||
AllowedCIDRs []string `json:"allowed_cidrs,omitempty" gorm:"serializer:json"`
|
||||
BlockedCIDRs []string `json:"blocked_cidrs,omitempty" gorm:"serializer:json"`
|
||||
AllowedCountries []string `json:"allowed_countries,omitempty" gorm:"serializer:json"`
|
||||
BlockedCountries []string `json:"blocked_countries,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the AccessRestrictions.
|
||||
func (r AccessRestrictions) Copy() AccessRestrictions {
|
||||
return AccessRestrictions{
|
||||
AllowedCIDRs: slices.Clone(r.AllowedCIDRs),
|
||||
BlockedCIDRs: slices.Clone(r.BlockedCIDRs),
|
||||
AllowedCountries: slices.Clone(r.AllowedCountries),
|
||||
BlockedCountries: slices.Clone(r.BlockedCountries),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthConfig) HashSecrets() error {
|
||||
@@ -114,6 +142,16 @@ func (a *AuthConfig) HashSecrets() error {
|
||||
a.PinAuth.Pin = hashedPin
|
||||
}
|
||||
|
||||
for i, h := range a.HeaderAuths {
|
||||
if h != nil && h.Enabled && h.Value != "" {
|
||||
hashedValue, err := argon2id.Hash(h.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash header auth[%d] value: %w", i, err)
|
||||
}
|
||||
h.Value = hashedValue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -124,6 +162,11 @@ func (a *AuthConfig) ClearSecrets() {
|
||||
if a.PinAuth != nil {
|
||||
a.PinAuth.Pin = ""
|
||||
}
|
||||
for _, h := range a.HeaderAuths {
|
||||
if h != nil {
|
||||
h.Value = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
@@ -143,12 +186,13 @@ type Service struct {
|
||||
Enabled bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Restrictions AccessRestrictions `gorm:"serializer:json"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
// Mode determines the service type: "http", "tcp", "udp", or "tls".
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
@@ -188,6 +232,20 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.Auth.HeaderAuths) > 0 {
|
||||
apiHeaders := make([]api.HeaderAuthConfig, 0, len(s.Auth.HeaderAuths))
|
||||
for _, h := range s.Auth.HeaderAuths {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
apiHeaders = append(apiHeaders, api.HeaderAuthConfig{
|
||||
Enabled: h.Enabled,
|
||||
Header: h.Header,
|
||||
})
|
||||
}
|
||||
authConfig.HeaderAuths = &apiHeaders
|
||||
}
|
||||
|
||||
// Convert internal targets to API targets
|
||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
@@ -222,18 +280,19 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
listenPort := int(s.ListenPort)
|
||||
|
||||
resp := &api.Service{
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
Meta: meta,
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
AccessRestrictions: restrictionsToAPI(s.Restrictions),
|
||||
Meta: meta,
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -263,7 +322,16 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
auth.Oidc = true
|
||||
}
|
||||
|
||||
return &proto.ProxyMapping{
|
||||
for _, h := range s.Auth.HeaderAuths {
|
||||
if h != nil && h.Enabled {
|
||||
auth.HeaderAuths = append(auth.HeaderAuths, &proto.HeaderAuth{
|
||||
Header: h.Header,
|
||||
HashedValue: h.Value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: operationToProtoType(operation),
|
||||
Id: s.ID,
|
||||
Domain: s.Domain,
|
||||
@@ -276,6 +344,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
}
|
||||
|
||||
if r := restrictionsToProto(s.Restrictions); r != nil {
|
||||
mapping.AccessRestrictions = r
|
||||
}
|
||||
|
||||
return mapping
|
||||
}
|
||||
|
||||
// buildPathMappings constructs PathMapping entries from targets.
|
||||
@@ -334,8 +408,7 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
case Delete:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
default:
|
||||
log.Fatalf("unknown operation type: %v", op)
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
panic(fmt.Sprintf("unknown operation type: %v", op))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,6 +550,10 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
s.Auth = authFromAPI(req.Auth)
|
||||
}
|
||||
|
||||
if req.AccessRestrictions != nil {
|
||||
s.Restrictions = restrictionsFromAPI(req.AccessRestrictions)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -538,9 +615,70 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
|
||||
}
|
||||
auth.BearerAuth = bearerAuth
|
||||
}
|
||||
if reqAuth.HeaderAuths != nil {
|
||||
for _, h := range *reqAuth.HeaderAuths {
|
||||
auth.HeaderAuths = append(auth.HeaderAuths, &HeaderAuthConfig{
|
||||
Enabled: h.Enabled,
|
||||
Header: h.Header,
|
||||
Value: h.Value,
|
||||
})
|
||||
}
|
||||
}
|
||||
return auth
|
||||
}
|
||||
|
||||
func restrictionsFromAPI(r *api.AccessRestrictions) AccessRestrictions {
|
||||
if r == nil {
|
||||
return AccessRestrictions{}
|
||||
}
|
||||
var res AccessRestrictions
|
||||
if r.AllowedCidrs != nil {
|
||||
res.AllowedCIDRs = *r.AllowedCidrs
|
||||
}
|
||||
if r.BlockedCidrs != nil {
|
||||
res.BlockedCIDRs = *r.BlockedCidrs
|
||||
}
|
||||
if r.AllowedCountries != nil {
|
||||
res.AllowedCountries = *r.AllowedCountries
|
||||
}
|
||||
if r.BlockedCountries != nil {
|
||||
res.BlockedCountries = *r.BlockedCountries
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func restrictionsToAPI(r AccessRestrictions) *api.AccessRestrictions {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
|
||||
return nil
|
||||
}
|
||||
res := &api.AccessRestrictions{}
|
||||
if len(r.AllowedCIDRs) > 0 {
|
||||
res.AllowedCidrs = &r.AllowedCIDRs
|
||||
}
|
||||
if len(r.BlockedCIDRs) > 0 {
|
||||
res.BlockedCidrs = &r.BlockedCIDRs
|
||||
}
|
||||
if len(r.AllowedCountries) > 0 {
|
||||
res.AllowedCountries = &r.AllowedCountries
|
||||
}
|
||||
if len(r.BlockedCountries) > 0 {
|
||||
res.BlockedCountries = &r.BlockedCountries
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func restrictionsToProto(r AccessRestrictions) *proto.AccessRestrictions {
|
||||
if len(r.AllowedCIDRs) == 0 && len(r.BlockedCIDRs) == 0 && len(r.AllowedCountries) == 0 && len(r.BlockedCountries) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &proto.AccessRestrictions{
|
||||
AllowedCidrs: r.AllowedCIDRs,
|
||||
BlockedCidrs: r.BlockedCIDRs,
|
||||
AllowedCountries: r.AllowedCountries,
|
||||
BlockedCountries: r.BlockedCountries,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Validate() error {
|
||||
if s.Name == "" {
|
||||
return errors.New("service name is required")
|
||||
@@ -557,6 +695,13 @@ func (s *Service) Validate() error {
|
||||
s.Mode = ModeHTTP
|
||||
}
|
||||
|
||||
if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
return s.validateHTTPMode()
|
||||
@@ -657,6 +802,21 @@ func (s *Service) validateL4Target(target *Target) error {
|
||||
if target.Path != nil && *target.Path != "" && *target.Path != "/" {
|
||||
return errors.New("path is not supported for L4 services")
|
||||
}
|
||||
if target.Options.SessionIdleTimeout < 0 {
|
||||
return errors.New("session_idle_timeout must be positive for L4 services")
|
||||
}
|
||||
if target.Options.RequestTimeout < 0 {
|
||||
return errors.New("request_timeout must be positive for L4 services")
|
||||
}
|
||||
if target.Options.SkipTLSVerify {
|
||||
return errors.New("skip_tls_verify is not supported for L4 services")
|
||||
}
|
||||
if target.Options.PathRewrite != "" {
|
||||
return errors.New("path_rewrite is not supported for L4 services")
|
||||
}
|
||||
if len(target.Options.CustomHeaders) > 0 {
|
||||
return errors.New("custom_headers is not supported for L4 services")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -688,11 +848,9 @@ func IsPortBasedProtocol(mode string) bool {
|
||||
}
|
||||
|
||||
const (
|
||||
maxRequestTimeout = 5 * time.Minute
|
||||
maxSessionIdleTimeout = 10 * time.Minute
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
|
||||
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
|
||||
@@ -731,22 +889,12 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
|
||||
}
|
||||
|
||||
if opts.RequestTimeout != 0 {
|
||||
if opts.RequestTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||
}
|
||||
if opts.RequestTimeout > maxRequestTimeout {
|
||||
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
|
||||
}
|
||||
if opts.RequestTimeout < 0 {
|
||||
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||
}
|
||||
|
||||
if opts.SessionIdleTimeout != 0 {
|
||||
if opts.SessionIdleTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
|
||||
}
|
||||
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
|
||||
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
|
||||
}
|
||||
if opts.SessionIdleTimeout < 0 {
|
||||
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
|
||||
}
|
||||
|
||||
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
|
||||
@@ -796,6 +944,93 @@ func containsCRLF(s string) bool {
|
||||
return strings.ContainsAny(s, "\r\n")
|
||||
}
|
||||
|
||||
func validateHeaderAuths(headers []*HeaderAuthConfig) error {
|
||||
seen := make(map[string]struct{})
|
||||
for i, h := range headers {
|
||||
if h == nil || !h.Enabled {
|
||||
continue
|
||||
}
|
||||
if h.Header == "" {
|
||||
return fmt.Errorf("header_auths[%d]: header name is required", i)
|
||||
}
|
||||
if !httpHeaderNameRe.MatchString(h.Header) {
|
||||
return fmt.Errorf("header_auths[%d]: header name %q is not a valid HTTP header name", i, h.Header)
|
||||
}
|
||||
canonical := http.CanonicalHeaderKey(h.Header)
|
||||
if _, ok := hopByHopHeaders[canonical]; ok {
|
||||
return fmt.Errorf("header_auths[%d]: header %q is a hop-by-hop header and cannot be used for auth", i, h.Header)
|
||||
}
|
||||
if _, ok := reservedHeaders[canonical]; ok {
|
||||
return fmt.Errorf("header_auths[%d]: header %q is managed by the proxy and cannot be used for auth", i, h.Header)
|
||||
}
|
||||
if canonical == "Host" {
|
||||
return fmt.Errorf("header_auths[%d]: Host header cannot be used for auth", i)
|
||||
}
|
||||
if _, dup := seen[canonical]; dup {
|
||||
return fmt.Errorf("header_auths[%d]: duplicate header %q (same canonical form already configured)", i, h.Header)
|
||||
}
|
||||
seen[canonical] = struct{}{}
|
||||
if len(h.Value) > maxHeaderValueLen {
|
||||
return fmt.Errorf("header_auths[%d]: value exceeds maximum length of %d", i, maxHeaderValueLen)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxCIDREntries = 200
|
||||
maxCountryEntries = 50
|
||||
)
|
||||
|
||||
// validateAccessRestrictions validates and normalizes access restriction
|
||||
// entries. Country codes are uppercased in place.
|
||||
func validateAccessRestrictions(r *AccessRestrictions) error {
|
||||
if len(r.AllowedCIDRs) > maxCIDREntries {
|
||||
return fmt.Errorf("allowed_cidrs: exceeds maximum of %d entries", maxCIDREntries)
|
||||
}
|
||||
if len(r.BlockedCIDRs) > maxCIDREntries {
|
||||
return fmt.Errorf("blocked_cidrs: exceeds maximum of %d entries", maxCIDREntries)
|
||||
}
|
||||
if len(r.AllowedCountries) > maxCountryEntries {
|
||||
return fmt.Errorf("allowed_countries: exceeds maximum of %d entries", maxCountryEntries)
|
||||
}
|
||||
if len(r.BlockedCountries) > maxCountryEntries {
|
||||
return fmt.Errorf("blocked_countries: exceeds maximum of %d entries", maxCountryEntries)
|
||||
}
|
||||
|
||||
for i, raw := range r.AllowedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allowed_cidrs[%d]: %w", i, err)
|
||||
}
|
||||
if prefix != prefix.Masked() {
|
||||
return fmt.Errorf("allowed_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
|
||||
}
|
||||
}
|
||||
for i, raw := range r.BlockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("blocked_cidrs[%d]: %w", i, err)
|
||||
}
|
||||
if prefix != prefix.Masked() {
|
||||
return fmt.Errorf("blocked_cidrs[%d]: %q has host bits set, use %s instead", i, raw, prefix.Masked())
|
||||
}
|
||||
}
|
||||
for i, code := range r.AllowedCountries {
|
||||
if len(code) != 2 {
|
||||
return fmt.Errorf("allowed_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
|
||||
}
|
||||
r.AllowedCountries[i] = strings.ToUpper(code)
|
||||
}
|
||||
for i, code := range r.BlockedCountries {
|
||||
if len(code) != 2 {
|
||||
return fmt.Errorf("blocked_countries[%d]: %q must be a 2-letter ISO 3166-1 alpha-2 code", i, code)
|
||||
}
|
||||
r.BlockedCountries[i] = strings.ToUpper(code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
meta := map[string]any{
|
||||
"name": s.Name,
|
||||
@@ -827,9 +1062,17 @@ func (s *Service) EventMeta() map[string]any {
|
||||
}
|
||||
|
||||
func (s *Service) isAuthEnabled() bool {
|
||||
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
|
||||
if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
|
||||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
|
||||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled)
|
||||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) {
|
||||
return true
|
||||
}
|
||||
for _, h := range s.Auth.HeaderAuths {
|
||||
if h != nil && h.Enabled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) Copy() *Service {
|
||||
@@ -866,6 +1109,16 @@ func (s *Service) Copy() *Service {
|
||||
}
|
||||
authCopy.BearerAuth = &ba
|
||||
}
|
||||
if len(s.Auth.HeaderAuths) > 0 {
|
||||
authCopy.HeaderAuths = make([]*HeaderAuthConfig, len(s.Auth.HeaderAuths))
|
||||
for i, h := range s.Auth.HeaderAuths {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
hCopy := *h
|
||||
authCopy.HeaderAuths[i] = &hCopy
|
||||
}
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
@@ -878,6 +1131,7 @@ func (s *Service) Copy() *Service {
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Auth: authCopy,
|
||||
Restrictions: s.Restrictions.Copy(),
|
||||
Meta: s.Meta,
|
||||
SessionPrivateKey: s.SessionPrivateKey,
|
||||
SessionPublicKey: s.SessionPublicKey,
|
||||
|
||||
Reference in New Issue
Block a user