diff --git a/combined/cmd/root.go b/combined/cmd/root.go index db986b4d4..78223880a 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -67,6 +67,10 @@ func init() { rootCmd.AddCommand(newTokenCommands()) } +func RootCmd() *cobra.Command { + return rootCmd +} + func Execute() error { return rootCmd.Execute() } @@ -168,7 +172,7 @@ func initializeConfig() error { // serverInstances holds all server instances created during startup. type serverInstances struct { relaySrv *relayServer.Server - mgmtSrv *mgmtServer.BaseServer + mgmtSrv mgmtServer.Server signalSrv *signalServer.Server healthcheck *healthcheck.Server stunServer *stun.Server @@ -324,19 +328,24 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) { return } - servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) { - grpcSrv := s.GRPCServer() + if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok { + if baseServer, ok := s.(*mgmtServer.BaseServer); ok { + baseServer.AfterInit(func(s *mgmtServer.BaseServer) { + grpcSrv := s.GRPCServer() - if servers.signalSrv != nil { - proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv) - log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) - } + if servers.signalSrv != nil { + proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv) + log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) + } - s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) - if servers.relaySrv != nil { - log.Infof("Relay WebSocket handler added (path: /relay)") + s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) + if servers.relaySrv != nil { + log.Infof("Relay WebSocket handler added (path: /relay)") + } + }) } - }) + } + } func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) { @@ -377,7 +386,7 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck * } } -func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error { +func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error { var errs error if err := httpHealthcheck.Shutdown(ctx); err != nil { @@ -491,7 +500,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) { return nil, false, nil } -func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { +func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) { mgmt := cfg.Management // Extract port from listen address @@ -502,7 +511,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* } mgmtPort, _ := strconv.Atoi(portStr) - mgmtSrv := mgmtServer.NewServer( + mgmtSrv := newServer( &mgmtServer.Config{ NbConfig: mgmtConfig, DNSDomain: "", diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 9b8716da1..ef31ee4f2 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -34,6 +34,8 @@ const ( ManagementLegacyPort = 33073 // DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs. DefaultSelfHostedDomain = "netbird.selfhosted" + + ContainerKeyBaseServer = "baseServer" ) type Server interface { @@ -91,7 +93,7 @@ type Config struct { // NewServer initializes and configures a new Server instance func NewServer(cfg *Config) *BaseServer { - return &BaseServer{ + s := &BaseServer{ Config: cfg.NbConfig, container: make(map[string]any), dnsDomain: cfg.DNSDomain, @@ -104,6 +106,9 @@ func NewServer(cfg *Config) *BaseServer { mgmtMetricsPort: cfg.MgmtMetricsPort, autoResolveDomains: cfg.AutoResolveDomains, } + s.container[ContainerKeyBaseServer] = s + + return s } func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {