Refactor SSH server to manage listener lifecycle and expose active address via Addr method. (#5036)

This commit is contained in:
Zoltan Papp
2026-01-07 15:34:26 +01:00
committed by GitHub
parent e586c20e36
commit 6ff9aa0366
2 changed files with 26 additions and 15 deletions

View File

@@ -136,6 +136,7 @@ type sessionState struct {
type Server struct {
sshServer *ssh.Server
listener net.Listener
mu sync.RWMutex
hostKeyPEM []byte
@@ -151,7 +152,6 @@ type Server struct {
// Populated at authentication time, stores JWT username and port forwards for status display.
connections map[connKey]*connState
allowLocalPortForwarding bool
allowRemotePortForwarding bool
allowRootLogin bool
@@ -240,6 +240,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return fmt.Errorf("create SSH server: %w", err)
}
s.listener = ln
s.sshServer = sshServer
log.Infof("SSH server started on %s", addrDesc)
@@ -292,6 +293,7 @@ func (s *Server) Stop() error {
}
s.sshServer = nil
s.listener = nil
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
@@ -307,6 +309,18 @@ func (s *Server) Stop() error {
return nil
}
// Addr returns the address the SSH server is listening on, or nil if the server is not running
func (s *Server) Addr() net.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
if s.listener == nil {
return nil
}
return s.listener.Addr()
}
// GetStatus returns the current status of the SSH server and active sessions.
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()

View File

@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
@@ -14,23 +13,21 @@ func StartTestServer(t *testing.T, server *Server) string {
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort := netip.MustParseAddrPort(actualAddr)
// Use port 0 to let the OS assign a free port
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
// Get the actual listening address from the server
actualAddr := server.Addr()
if actualAddr == nil {
errChan <- fmt.Errorf("server started but no listener address available")
return
}
started <- actualAddr.String()
}()
select {