Merge remote-tracking branch 'origin/prototype/reverse-proxy-clusters' into prototype/reverse-proxy

# Conflicts:
#	management/internals/modules/reverseproxy/manager/manager.go
#	management/internals/modules/reverseproxy/reverseproxy.go
#	management/internals/server/modules.go
#	management/internals/shared/grpc/proxy.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-02-05 15:19:57 +01:00
22 changed files with 466 additions and 126 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/url"
"strings"
"github.com/netbirdio/netbird/management/server/types"
log "github.com/sirupsen/logrus"
@@ -57,21 +58,31 @@ func NewManager(store store, proxyURLProvider proxyURLProvider) Manager {
}
func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, error) {
account, err := m.store.GetAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
free, err := m.store.ListFreeDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list free domains: %w", err)
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
}
var ret []*Domain
// Populate all fields correctly for custom domains that are retrieved.
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
for _, cluster := range allowList {
ret = append(ret, &Domain{
Domain: cluster,
AccountID: accountID,
Type: TypeFree,
Validated: true,
})
}
// Add custom domains.
for _, domain := range domains {
ret = append(ret, &Domain{
ID: domain.ID,
@@ -82,19 +93,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, e
})
}
// Prepend each free domain with the account nonce and then add it to the domain
// array to be returned.
// This account nonce is added to free domains to prevent users being able to
// query free domain usage across accounts and simplifies tracking free domain
// usage across accounts.
for _, name := range free {
ret = append(ret, &Domain{
Domain: account.ReverseProxyFreeDomainNonce + "." + name,
AccountID: accountID,
Type: TypeFree,
Validated: true,
})
}
return ret, nil
}
@@ -124,6 +122,11 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) e
}
func (m Manager) ValidateDomain(accountID, domainID string) {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).Info("starting domain validation")
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
if err != nil {
log.WithFields(log.Fields{
@@ -133,12 +136,20 @@ func (m Manager) ValidateDomain(accountID, domainID string) {
return
}
if m.validator.IsValid(context.Background(), d.Domain, m.proxyURLAllowList()) {
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"proxyAllowList": allowList,
}).Info("validating domain against proxy allow list")
if m.validator.IsValid(context.Background(), d.Domain, allowList) {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Debug("domain validated successfully")
}).Info("domain validated successfully")
d.Validated = true
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
log.WithFields(log.Fields{
@@ -148,6 +159,13 @@ func (m Manager) ValidateDomain(accountID, domainID string) {
}).WithError(err).Error("update custom domain in store")
return
}
} else {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"proxyAllowList": allowList,
}).Warn("domain validation failed - CNAME does not match any connected proxy")
}
}
@@ -162,17 +180,60 @@ func (m Manager) proxyURLAllowList() []string {
}
var allowedProxyURLs []string
for _, addr := range reverseProxyAddresses {
proxyUrl, err := url.Parse(addr)
if err != nil {
// TODO: log?
if addr == "" {
continue
}
host, _, err := net.SplitHostPort(proxyUrl.Host)
if err != nil {
// TODO: log?
host = proxyUrl.Host
host := extractHostFromAddress(addr)
if host != "" {
allowedProxyURLs = append(allowedProxyURLs, host)
}
allowedProxyURLs = append(allowedProxyURLs, host)
}
return allowedProxyURLs
}
// extractHostFromAddress extracts the hostname from an address string.
// It handles both URL format (https://host:port) and plain hostname (host or host:port).
func extractHostFromAddress(addr string) string {
// If it looks like a URL with a scheme, parse it
if strings.Contains(addr, "://") {
proxyUrl, err := url.Parse(addr)
if err != nil {
log.WithError(err).Debugf("failed to parse proxy URL %s", addr)
return ""
}
host, _, err := net.SplitHostPort(proxyUrl.Host)
if err != nil {
return proxyUrl.Host
}
return host
}
// Otherwise treat as hostname or host:port
host, _, err := net.SplitHostPort(addr)
if err != nil {
// No port, use as-is
return addr
}
return host
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by looking up the CNAME target.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) {
allowList := m.proxyURLAllowList()
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
return cluster, nil
}
cluster, valid := m.validator.ValidateWithCluster(ctx, domain, allowList)
if valid {
return cluster, nil
}
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}

View File

@@ -32,28 +32,69 @@ func NewValidator(resolver resolver) *Validator {
// The comparison is very simple, so wildcards will not match if included
// in the acceptable domain list.
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
_, valid := v.ValidateWithCluster(ctx, domain, accept)
return valid
}
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
// Returns the cluster address and true if valid, or empty string and false if invalid.
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
if v.resolver == nil {
v.resolver = net.DefaultResolver
}
// Prepend subdomain for ownership validation because we want to check
// for the record being a wildcard ("*.example.com"), but you cannot
// look up a wildcard so we have to add a subdomain for the check.
cname, err := v.resolver.LookupCNAME(ctx, "validation."+domain)
lookupDomain := "validation." + domain
log.WithFields(log.Fields{
"domain": domain,
"lookupDomain": lookupDomain,
"acceptList": accept,
}).Debug("looking up CNAME for domain validation")
cname, err := v.resolver.LookupCNAME(ctx, lookupDomain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
}).WithError(err).Error("Error resolving CNAME from resolver")
return false
"domain": domain,
"lookupDomain": lookupDomain,
}).WithError(err).Warn("CNAME lookup failed for domain validation")
return "", false
}
// Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing).
nakedCNAME := strings.TrimSuffix(cname, ".")
for _, domain := range accept {
// Currently, the match is a very simple string comparison.
if nakedCNAME == strings.TrimSuffix(domain, ".") {
return true
log.WithFields(log.Fields{
"domain": domain,
"cname": cname,
"nakedCNAME": nakedCNAME,
"acceptList": accept,
}).Debug("CNAME lookup result for domain validation")
for _, acceptDomain := range accept {
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
if nakedCNAME == normalizedAccept {
log.WithFields(log.Fields{
"domain": domain,
"cname": nakedCNAME,
"cluster": acceptDomain,
}).Info("domain CNAME matched cluster")
return acceptDomain, true
}
}
return false
log.WithFields(log.Fields{
"domain": domain,
"cname": nakedCNAME,
"acceptList": accept,
}).Warn("domain CNAME does not match any accepted cluster")
return "", false
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
// It matches the domain suffix against available clusters and returns the matching cluster.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
for _, cluster := range availableClusters {
if strings.HasSuffix(domain, "."+cluster) {
return cluster, true
}
}
return "", false
}

View File

@@ -20,12 +20,12 @@ type handler struct {
manager reverseproxy.Manager
}
// RegisterEndpoints registers all reverse proxy HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
// Hang domain endpoints off the main router here.
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domain.RegisterEndpoints(domainRouter, domainManager)

View File

@@ -5,6 +5,10 @@ import (
"fmt"
"time"
"github.com/google/uuid"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -17,21 +21,29 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, domain string) (string, error)
}
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
tokenStore *nbgrpc.OneTimeTokenStore
clusterDeriver ClusterDeriver
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore) reverseproxy.Manager {
// NewManager creates a new reverse proxy manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
tokenStore: tokenStore,
clusterDeriver: clusterDeriver,
}
}
@@ -68,9 +80,17 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
return nil, status.NewPermissionDeniedError()
}
var proxyCluster string
if m.clusterDeriver != nil {
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxies", reverseProxy.Domain)
}
}
authConfig := reverseProxy.Auth
reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, reverseProxy.Targets, reverseProxy.Enabled)
reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, proxyCluster, reverseProxy.Targets, reverseProxy.Enabled)
reverseProxy.Auth = authConfig
@@ -111,7 +131,7 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyCreated, reverseProxy.EventMeta())
m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()))
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token), m.proxyGRPCServer.GetOIDCValidationConfig(), reverseProxy.ProxyCluster))
return reverseProxy, nil
}
@@ -125,28 +145,44 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
return nil, status.NewPermissionDeniedError()
}
var oldCluster string
var domainChanged bool
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Get existing reverse proxy
existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID)
if err != nil {
return err
}
// Check if domain changed and if it conflicts
oldCluster = existingReverseProxy.ProxyCluster
if existingReverseProxy.Domain != reverseProxy.Domain {
domainChanged = true
conflictReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing reverse proxy: %w", err)
return fmt.Errorf("check existing reverse proxy: %w", err)
}
}
if conflictReverseProxy != nil && conflictReverseProxy.ID != reverseProxy.ID {
return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain)
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", reverseProxy.Domain)
}
reverseProxy.ProxyCluster = newCluster
}
} else {
reverseProxy.ProxyCluster = existingReverseProxy.ProxyCluster
}
reverseProxy.Meta = existingReverseProxy.Meta
if err = transaction.UpdateReverseProxy(ctx, reverseProxy); err != nil {
return fmt.Errorf("failed to update reverse proxy: %w", err)
return fmt.Errorf("update reverse proxy: %w", err)
}
return nil
@@ -157,7 +193,12 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta())
m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()))
if domainChanged && oldCluster != reverseProxy.ProxyCluster {
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, ""), oldCluster, m.proxyGRPCServer.GetOIDCValidationConfig())
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, ""), reverseProxy.ProxyCluster, m.proxyGRPCServer.GetOIDCValidationConfig())
} else {
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, ""), reverseProxy.ProxyCluster, m.proxyGRPCServer.GetOIDCValidationConfig())
}
return reverseProxy, nil
}
@@ -191,7 +232,7 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID,
m.accountManager.StoreEvent(ctx, userID, reverseProxyID, accountID, activity.ReverseProxyDeleted, reverseProxy.EventMeta())
m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()))
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, ""), reverseProxy.ProxyCluster, m.proxyGRPCServer.GetOIDCValidationConfig())
return nil
}

View File

@@ -82,26 +82,28 @@ type ReverseProxyMeta struct {
}
type ReverseProxy struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
Targets []Target `gorm:"serializer:json"`
Enabled bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []Target `gorm:"serializer:json"`
Enabled bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy {
func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []Target, enabled bool) *ReverseProxy {
return &ReverseProxy{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
Targets: targets,
Enabled: enabled,
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
Meta: ReverseProxyMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
@@ -156,7 +158,7 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
meta.CertificateIssuedAt = &r.Meta.CertificateIssuedAt
}
return &api.ReverseProxy{
resp := &api.ReverseProxy{
Id: r.ID,
Name: r.Name,
Domain: r.Domain,
@@ -165,6 +167,12 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
Auth: authConfig,
Meta: meta,
}
if r.ProxyCluster != "" {
resp.ProxyCluster = &r.ProxyCluster
}
return resp
}
func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
@@ -302,7 +310,7 @@ func (r *ReverseProxy) Validate() error {
}
func (r *ReverseProxy) EventMeta() map[string]any {
return map[string]any{"name": r.Name, "domain": r.Domain}
return map[string]any{"name": r.Name, "domain": r.Domain, "proxy_cluster": r.ProxyCluster}
}
func (r *ReverseProxy) Copy() *ReverseProxy {