rename url flag to domain and update validation

This commit is contained in:
mlsmaycon
2026-02-12 16:28:29 +01:00
parent 41a5509ce0
commit ac995bae6d
3 changed files with 22 additions and 33 deletions

View File

@@ -9,13 +9,13 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/netbirdio/netbird/shared/management/domain"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -167,12 +167,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
proxyAddress := req.GetAddress() proxyAddress := req.GetAddress()
log.WithFields(log.Fields{ if !isProxyAddressValid(proxyAddress) {
"proxy_id": proxyID, return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
"address": proxyAddress, }
"version": req.GetVersion(),
"started": req.GetStartedAt().AsTime(),
}).Info("Proxy connected")
connCtx, cancel := context.WithCancel(ctx) connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{ conn := &proxyConnection{
@@ -189,7 +186,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"address": proxyAddress, "address": proxyAddress,
"cluster_addr": extractClusterAddr(proxyAddress), "cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()), "total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster") }).Info("Proxy registered in cluster")
defer func() { defer func() {
@@ -222,14 +219,16 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("get services from store: %w", err) return fmt.Errorf("get services from store: %w", err)
} }
proxyClusterAddr := extractClusterAddr(conn.address) if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
}
var filtered []*reverseproxy.Service var filtered []*reverseproxy.Service
for _, service := range services { for _, service := range services {
if !service.Enabled { if !service.Enabled {
continue continue
} }
if service.ProxyCluster != "" && proxyClusterAddr != "" && service.ProxyCluster != proxyClusterAddr { if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue continue
} }
filtered = append(filtered, service) filtered = append(filtered, service)
@@ -277,20 +276,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return nil return nil
} }
// extractClusterAddr extracts the host from a proxy address URL. // isProxyAddressValid validates a proxy address
func extractClusterAddr(addr string) string { func isProxyAddressValid(addr string) bool {
if addr == "" { _, err := domain.ValidateDomains([]string{addr})
return "" return err == nil
}
u, err := url.Parse(addr)
if err != nil {
return addr
}
host := u.Host
if h, _, err := net.SplitHostPort(host); err == nil {
return h
}
return host
} }
// sender handles sending messages to proxy // sender handles sending messages to proxy

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/netbirdio/netbird/shared/management/domain"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
@@ -36,7 +37,7 @@ var (
debugLogs bool debugLogs bool
mgmtAddr string mgmtAddr string
addr string addr string
proxyURL string proxyDomain string
certDir string certDir string
acmeCerts bool acmeCerts bool
acmeAddr string acmeAddr string
@@ -69,7 +70,7 @@ func init() {
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs") rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to") rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on") rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached") rootCmd.Flags().StringVar(&proxyDomain, "domain", envStringOrDefault("NB_PROXY_DOMAIN", ""), "The Domain at which this proxy will be reached. e.g., netbird.example.com")
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates") rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges") rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges") rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
@@ -128,6 +129,11 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto) return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto)
} }
_, err := domain.ValidateDomains([]string{proxyDomain})
if err != nil {
return fmt.Errorf("invalid domain value %q: %w", proxyDomain, err)
}
parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies) parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies)
if err != nil { if err != nil {
return fmt.Errorf("invalid --trusted-proxies: %w", err) return fmt.Errorf("invalid --trusted-proxies: %w", err)
@@ -137,7 +143,7 @@ func runServer(cmd *cobra.Command, args []string) error {
Logger: logger, Logger: logger,
Version: Version, Version: Version,
ManagementAddress: mgmtAddr, ManagementAddress: mgmtAddr,
ProxyURL: proxyURL, ProxyURL: proxyDomain,
ProxyToken: proxyToken, ProxyToken: proxyToken,
CertificateDirectory: certDir, CertificateDirectory: certDir,
CertificateFile: certFile, CertificateFile: certFile,

View File

@@ -218,12 +218,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
}() }()
tlsConfig = s.acme.TLSConfig() tlsConfig = s.acme.TLSConfig()
// If the ProxyURL is not set, then fallback to the server address.
// Hopefully that should give at least something that we can use.
// If it doesn't, then autocert probably won't work correctly.
if s.ProxyURL == "" {
s.ProxyURL, _, _ = net.SplitHostPort(addr)
}
// ServerName needs to be set to allow for ACME to work correctly // ServerName needs to be set to allow for ACME to work correctly
// when using CNAME URLs to access the proxy. // when using CNAME URLs to access the proxy.
tlsConfig.ServerName = s.ProxyURL tlsConfig.ServerName = s.ProxyURL