diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index f957e66a5..3a8568979 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -136,6 +136,7 @@ type sessionState struct { type Server struct { sshServer *ssh.Server + listener net.Listener mu sync.RWMutex hostKeyPEM []byte @@ -151,7 +152,6 @@ type Server struct { // Populated at authentication time, stores JWT username and port forwards for status display. connections map[connKey]*connState - allowLocalPortForwarding bool allowRemotePortForwarding bool allowRootLogin bool @@ -240,6 +240,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return fmt.Errorf("create SSH server: %w", err) } + s.listener = ln s.sshServer = sshServer log.Infof("SSH server started on %s", addrDesc) @@ -292,6 +293,7 @@ func (s *Server) Stop() error { } s.sshServer = nil + s.listener = nil maps.Clear(s.sessions) maps.Clear(s.pendingAuthJWT) @@ -307,6 +309,18 @@ func (s *Server) Stop() error { return nil } +// Addr returns the address the SSH server is listening on, or nil if the server is not running +func (s *Server) Addr() net.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.listener == nil { + return nil + } + + return s.listener.Addr() +} + // GetStatus returns the current status of the SSH server and active sessions. func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { s.mu.RLock() diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go index 20930c721..f8abd1752 100644 --- a/client/ssh/server/test.go +++ b/client/ssh/server/test.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "net" "net/netip" "testing" "time" @@ -14,23 +13,21 @@ func StartTestServer(t *testing.T, server *Server) string { errChan := make(chan error, 1) go func() { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - errChan <- err - return - } - actualAddr := ln.Addr().String() - if err := ln.Close(); err != nil { - errChan <- fmt.Errorf("close temp listener: %w", err) - return - } - - addrPort := netip.MustParseAddrPort(actualAddr) + // Use port 0 to let the OS assign a free port + addrPort := netip.MustParseAddrPort("127.0.0.1:0") if err := server.Start(context.Background(), addrPort); err != nil { errChan <- err return } - started <- actualAddr + + // Get the actual listening address from the server + actualAddr := server.Addr() + if actualAddr == nil { + errChan <- fmt.Errorf("server started but no listener address available") + return + } + + started <- actualAddr.String() }() select {