mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
rename url flag to domain and update validation
This commit is contained in:
@@ -9,13 +9,13 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@@ -167,12 +167,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxyAddress := req.GetAddress()
|
proxyAddress := req.GetAddress()
|
||||||
log.WithFields(log.Fields{
|
if !isProxyAddressValid(proxyAddress) {
|
||||||
"proxy_id": proxyID,
|
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||||
"address": proxyAddress,
|
}
|
||||||
"version": req.GetVersion(),
|
|
||||||
"started": req.GetStartedAt().AsTime(),
|
|
||||||
}).Info("Proxy connected")
|
|
||||||
|
|
||||||
connCtx, cancel := context.WithCancel(ctx)
|
connCtx, cancel := context.WithCancel(ctx)
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
@@ -189,7 +186,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"proxy_id": proxyID,
|
"proxy_id": proxyID,
|
||||||
"address": proxyAddress,
|
"address": proxyAddress,
|
||||||
"cluster_addr": extractClusterAddr(proxyAddress),
|
"cluster_addr": proxyAddress,
|
||||||
"total_proxies": len(s.GetConnectedProxies()),
|
"total_proxies": len(s.GetConnectedProxies()),
|
||||||
}).Info("Proxy registered in cluster")
|
}).Info("Proxy registered in cluster")
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -222,14 +219,16 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
return fmt.Errorf("get services from store: %w", err)
|
return fmt.Errorf("get services from store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyClusterAddr := extractClusterAddr(conn.address)
|
if !isProxyAddressValid(conn.address) {
|
||||||
|
return fmt.Errorf("proxy address is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
var filtered []*reverseproxy.Service
|
var filtered []*reverseproxy.Service
|
||||||
for _, service := range services {
|
for _, service := range services {
|
||||||
if !service.Enabled {
|
if !service.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if service.ProxyCluster != "" && proxyClusterAddr != "" && service.ProxyCluster != proxyClusterAddr {
|
if service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
filtered = append(filtered, service)
|
filtered = append(filtered, service)
|
||||||
@@ -277,20 +276,10 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractClusterAddr extracts the host from a proxy address URL.
|
// isProxyAddressValid validates a proxy address
|
||||||
func extractClusterAddr(addr string) string {
|
func isProxyAddressValid(addr string) bool {
|
||||||
if addr == "" {
|
_, err := domain.ValidateDomains([]string{addr})
|
||||||
return ""
|
return err == nil
|
||||||
}
|
|
||||||
u, err := url.Parse(addr)
|
|
||||||
if err != nil {
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
host := u.Host
|
|
||||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sender handles sending messages to proxy
|
// sender handles sending messages to proxy
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/crypto/acme"
|
"golang.org/x/crypto/acme"
|
||||||
@@ -36,7 +37,7 @@ var (
|
|||||||
debugLogs bool
|
debugLogs bool
|
||||||
mgmtAddr string
|
mgmtAddr string
|
||||||
addr string
|
addr string
|
||||||
proxyURL string
|
proxyDomain string
|
||||||
certDir string
|
certDir string
|
||||||
acmeCerts bool
|
acmeCerts bool
|
||||||
acmeAddr string
|
acmeAddr string
|
||||||
@@ -69,7 +70,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
|
rootCmd.PersistentFlags().BoolVar(&debugLogs, "debug", envBoolOrDefault("NB_PROXY_DEBUG_LOGS", false), "Enable debug logs")
|
||||||
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
|
rootCmd.Flags().StringVar(&mgmtAddr, "mgmt", envStringOrDefault("NB_PROXY_MANAGEMENT_ADDRESS", DefaultManagementURL), "Management address to connect to")
|
||||||
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
|
rootCmd.Flags().StringVar(&addr, "addr", envStringOrDefault("NB_PROXY_ADDRESS", ":443"), "Reverse proxy address to listen on")
|
||||||
rootCmd.Flags().StringVar(&proxyURL, "url", envStringOrDefault("NB_PROXY_URL", ""), "The URL at which this proxy will be reached")
|
rootCmd.Flags().StringVar(&proxyDomain, "domain", envStringOrDefault("NB_PROXY_DOMAIN", ""), "The Domain at which this proxy will be reached. e.g., netbird.example.com")
|
||||||
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
|
rootCmd.Flags().StringVar(&certDir, "cert-dir", envStringOrDefault("NB_PROXY_CERTIFICATE_DIRECTORY", "./certs"), "Directory to store certificates")
|
||||||
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
|
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates using HTTP-01 challenges")
|
||||||
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
|
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges")
|
||||||
@@ -128,6 +129,11 @@ func runServer(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto)
|
return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err := domain.ValidateDomains([]string{proxyDomain})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid domain value %q: %w", proxyDomain, err)
|
||||||
|
}
|
||||||
|
|
||||||
parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies)
|
parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid --trusted-proxies: %w", err)
|
return fmt.Errorf("invalid --trusted-proxies: %w", err)
|
||||||
@@ -137,7 +143,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
|||||||
Logger: logger,
|
Logger: logger,
|
||||||
Version: Version,
|
Version: Version,
|
||||||
ManagementAddress: mgmtAddr,
|
ManagementAddress: mgmtAddr,
|
||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyDomain,
|
||||||
ProxyToken: proxyToken,
|
ProxyToken: proxyToken,
|
||||||
CertificateDirectory: certDir,
|
CertificateDirectory: certDir,
|
||||||
CertificateFile: certFile,
|
CertificateFile: certFile,
|
||||||
|
|||||||
@@ -218,12 +218,6 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
}()
|
}()
|
||||||
tlsConfig = s.acme.TLSConfig()
|
tlsConfig = s.acme.TLSConfig()
|
||||||
|
|
||||||
// If the ProxyURL is not set, then fallback to the server address.
|
|
||||||
// Hopefully that should give at least something that we can use.
|
|
||||||
// If it doesn't, then autocert probably won't work correctly.
|
|
||||||
if s.ProxyURL == "" {
|
|
||||||
s.ProxyURL, _, _ = net.SplitHostPort(addr)
|
|
||||||
}
|
|
||||||
// ServerName needs to be set to allow for ACME to work correctly
|
// ServerName needs to be set to allow for ACME to work correctly
|
||||||
// when using CNAME URLs to access the proxy.
|
// when using CNAME URLs to access the proxy.
|
||||||
tlsConfig.ServerName = s.ProxyURL
|
tlsConfig.ServerName = s.ProxyURL
|
||||||
|
|||||||
Reference in New Issue
Block a user