diff --git a/management/cmd/defaults.go b/management/cmd/defaults.go index 424a29507..7b028fd43 100644 --- a/management/cmd/defaults.go +++ b/management/cmd/defaults.go @@ -1,5 +1,11 @@ package cmd +import ( + "fmt" + + "github.com/netbirdio/netbird/management/internals/server" +) + const ( defaultMgmtDataDir = "/var/lib/netbird/" defaultMgmtConfigDir = "/etc/netbird" @@ -16,3 +22,5 @@ const ( defaultSingleAccModeDomain = "netbird.selfhosted" ) + +var defaultWSProxyAddr = fmt.Sprintf("127.0.0.1:%d", server.ManagementLegacyPort) diff --git a/management/cmd/management.go b/management/cmd/management.go index 37ba0ae16..5f3a0074d 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "io/fs" + "net" "net/http" "net/url" "os" @@ -26,11 +27,11 @@ import ( "github.com/netbirdio/netbird/util" ) -var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server { - return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled) +var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool, wsProxyAddr string) server.Server { + return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled, wsProxyAddr) } -func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) { +func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool, wsProxyAddr string) server.Server) { newServer = fn } @@ -82,6 +83,10 @@ var ( return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain)) } + if _, _, err := net.SplitHostPort(wsProxyAddr); err != nil { + return fmt.Errorf("invalid ws-proxy-backend-addr format: %w", err) + } + return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -108,7 +113,7 @@ var ( mgmtSingleAccModeDomain = "" } - srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled) + srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled, wsProxyAddr) go func() { if err := srv.Start(cmd.Context()); err != nil { log.Fatalf("Server error: %v", err) diff --git a/management/cmd/root.go b/management/cmd/root.go index b60f79c23..2e289d8c5 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -31,6 +31,7 @@ var ( mgmtSingleAccModeDomain string certFile string certKey string + wsProxyAddr string rootCmd = &cobra.Command{ Use: "netbird-mgmt", @@ -57,6 +58,7 @@ func init() { mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") + mgmtCmd.Flags().StringVar(&wsProxyAddr, "ws-proxy-backend-addr", defaultWSProxyAddr, "WebSocket proxy backend address in host:port format that the proxy will dial") mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain) diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 94c633fc6..41577d7f0 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "net/netip" "strings" "sync" "time" @@ -58,6 +57,7 @@ type BaseServer struct { mgmtSingleAccModeDomain string mgmtMetricsPort int mgmtPort int + wsProxyAddr string listener net.Listener certManager *autocert.Manager @@ -69,7 +69,7 @@ type BaseServer struct { } // NewServer initializes and configures a new Server instance -func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer { +func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool, wsProxyAddr string) *BaseServer { return &BaseServer{ config: config, container: make(map[string]any), @@ -80,6 +80,7 @@ func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain strin userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, mgmtPort: mgmtPort, mgmtMetricsPort: mgmtMetricsPort, + wsProxyAddr: wsProxyAddr, } } @@ -182,6 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) + log.WithContext(ctx).Infof("WebSocket proxy configured to forward to: %s", s.wsProxyAddr) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) s.update = version.NewUpdate("nb/management") @@ -252,7 +254,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) } func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(s.wsProxyAddr, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { switch { diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go index 977440a60..f19a0525c 100644 --- a/util/wsproxy/server/proxy.go +++ b/util/wsproxy/server/proxy.go @@ -6,7 +6,6 @@ import ( "io" "net" "net/http" - "net/netip" "sync" "time" @@ -23,7 +22,7 @@ const ( // Config contains the configuration for the WebSocket proxy. type Config struct { - LocalGRPCAddr netip.AddrPort + LocalGRPCAddr string Path string MetricsRecorder MetricsRecorder } @@ -35,7 +34,7 @@ type Proxy struct { } // New creates a new WebSocket proxy instance with optional configuration -func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { +func New(localGRPCAddr string, opts ...Option) *Proxy { config := Config{ LocalGRPCAddr: localGRPCAddr, Path: wsproxy.ProxyPath, @@ -81,7 +80,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { }() log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) - tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) + tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr, dialTimeout) if err != nil { p.metrics.RecordError(ctx, "tcp_dial_failed") log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)