diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 9d96678a3..1125f428f 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "net/url" "strings" log "github.com/sirupsen/logrus" @@ -223,57 +222,19 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID } // proxyURLAllowList retrieves a list of currently connected proxies and -// their URLs (as reported by the proxy servers). It performs some clean -// up on those URLs to attempt to retrieve domain names as we would -// expect to see them in a validation check. +// their URLs func (m Manager) proxyURLAllowList() []string { var reverseProxyAddresses []string if m.proxyURLProvider != nil { reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs() } - var allowedProxyURLs []string - for _, addr := range reverseProxyAddresses { - if addr == "" { - continue - } - host := extractHostFromAddress(addr) - if host != "" { - allowedProxyURLs = append(allowedProxyURLs, host) - } - } - return allowedProxyURLs -} - -// extractHostFromAddress extracts the hostname from an address string. -// It handles both URL format (https://host:port) and plain hostname (host or host:port). -func extractHostFromAddress(addr string) string { - // If it looks like a URL with a scheme, parse it - if strings.Contains(addr, "://") { - proxyUrl, err := url.Parse(addr) - if err != nil { - log.WithError(err).Debugf("failed to parse proxy URL %s", addr) - return "" - } - host, _, err := net.SplitHostPort(proxyUrl.Host) - if err != nil { - return proxyUrl.Host - } - return host - } - - // Otherwise treat as hostname or host:port - host, _, err := net.SplitHostPort(addr) - if err != nil { - // No port, use as-is - return addr - } - return host + return reverseProxyAddresses } // DeriveClusterFromDomain determines the proxy cluster for a given domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. -// For custom domains, the cluster is determined by looking up the CNAME target. -func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) { +// For custom domains, the cluster is determined by checking the registered custom domain's target cluster. +func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { allowList := m.proxyURLAllowList() if len(allowList) == 0 { return "", fmt.Errorf("no proxy clusters available") @@ -283,14 +244,28 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (st return cluster, nil } - cluster, valid := m.validator.ValidateWithCluster(ctx, domain, allowList) + customDomains, err := m.store.ListCustomDomains(ctx, accountID) + if err != nil { + return "", fmt.Errorf("list custom domains: %w", err) + } + + targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains) if valid { - return cluster, nil + return targetCluster, nil } return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } +func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) { + for _, customDomain := range customDomains { + if strings.HasSuffix(domain, "."+customDomain.Domain) { + return customDomain.TargetCluster, true + } + } + return "", false +} + // ExtractClusterFromFreeDomain extracts the cluster address from a free domain. // Free domains have the format: .. (e.g., myapp.abc123.eu.proxy.netbird.io) // It matches the domain suffix against available clusters and returns the matching cluster. diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 12c57cbd0..df0c8fea6 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -23,7 +23,7 @@ const unknownHostPlaceholder = "unknown" // ClusterDeriver derives the proxy cluster from a domain. type ClusterDeriver interface { - DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) + DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) } type managerImpl struct { @@ -137,7 +137,7 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin var proxyCluster string if m.clusterDeriver != nil { - proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, service.Domain) + proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) if err != nil { log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain) return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err) @@ -239,7 +239,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin } if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, service.Domain) + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) if err != nil { log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain) }