From e9dbf9db6f43dc3799634ad6bca2beae3cc4c88d Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 29 May 2026 17:35:35 +0300 Subject: [PATCH] [management] Extend combined server initialization (#6156) --- combined/cmd/root.go | 55 ++++++++++++++------------- combined/cmd/server.go | 13 +++++++ management/internals/server/server.go | 7 +++- 3 files changed, 48 insertions(+), 27 deletions(-) create mode 100644 combined/cmd/server.go diff --git a/combined/cmd/root.go b/combined/cmd/root.go index 78290388b..31e0580fb 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -67,6 +67,10 @@ func init() { rootCmd.AddCommand(newTokenCommands()) } +func RootCmd() *cobra.Command { + return rootCmd +} + func Execute() error { return rootCmd.Execute() } @@ -168,7 +172,7 @@ func initializeConfig() error { // serverInstances holds all server instances created during startup. type serverInstances struct { relaySrv *relayServer.Server - mgmtSrv *mgmtServer.BaseServer + mgmtSrv mgmtServer.Server signalSrv *signalServer.Server healthcheck *healthcheck.Server stunServer *stun.Server @@ -324,19 +328,24 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) { return } - servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) { - grpcSrv := s.GRPCServer() + if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok { + if baseServer, ok := s.(*mgmtServer.BaseServer); ok { + baseServer.AfterInit(func(s *mgmtServer.BaseServer) { + grpcSrv := s.GRPCServer() - if servers.signalSrv != nil { - proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv) - log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) - } + if servers.signalSrv != nil { + proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv) + log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) + } - s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) - if servers.relaySrv != nil { - log.Infof("Relay WebSocket handler added (path: /relay)") + s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) + if servers.relaySrv != nil { + log.Infof("Relay WebSocket handler added (path: /relay)") + } + }) } - }) + } + } func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) { @@ -346,38 +355,32 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck * log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)") } - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { log.Fatalf("failed to start metrics server: %v", err) } - }() + }) - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { log.Fatalf("failed to start healthcheck server: %v", err) } - }() + }) if stunServer != nil { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { if err := stunServer.Listen(); err != nil { if errors.Is(err, stun.ErrServerClosed) { return } log.Errorf("STUN server error: %v", err) } - }() + }) } } -func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error { +func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error { var errs error if err := httpHealthcheck.Shutdown(ctx); err != nil { @@ -491,7 +494,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) { return nil, false, nil } -func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { +func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) { mgmt := cfg.Management // Extract port from listen address @@ -502,7 +505,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* } mgmtPort, _ := strconv.Atoi(portStr) - mgmtSrv := mgmtServer.NewServer( + mgmtSrv := newServer( &mgmtServer.Config{ NbConfig: mgmtConfig, DNSDomain: "", diff --git a/combined/cmd/server.go b/combined/cmd/server.go new file mode 100644 index 000000000..f9384dfb1 --- /dev/null +++ b/combined/cmd/server.go @@ -0,0 +1,13 @@ +package cmd + +import ( + mgmtServer "github.com/netbirdio/netbird/management/internals/server" +) + +var newServer = func(cfg *mgmtServer.Config) mgmtServer.Server { + return mgmtServer.NewServer(cfg) +} + +func SetNewServer(fn func(*mgmtServer.Config) mgmtServer.Server) { + newServer = fn +} diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 63d13baab..43ee2126d 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -34,6 +34,8 @@ const ( ManagementLegacyPort = 33073 // DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs. DefaultSelfHostedDomain = "netbird.selfhosted" + + ContainerKeyBaseServer = "baseServer" ) type Server interface { @@ -91,7 +93,7 @@ type Config struct { // NewServer initializes and configures a new Server instance func NewServer(cfg *Config) *BaseServer { - return &BaseServer{ + s := &BaseServer{ Config: cfg.NbConfig, container: make(map[string]any), dnsDomain: cfg.DNSDomain, @@ -104,6 +106,9 @@ func NewServer(cfg *Config) *BaseServer { mgmtMetricsPort: cfg.MgmtMetricsPort, autoResolveDomains: cfg.AutoResolveDomains, } + s.container[ContainerKeyBaseServer] = s + + return s } func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {