mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Refactor SSH server to manage listener lifecycle and expose active address via Addr method. (#5036)
This commit is contained in:
@@ -136,6 +136,7 @@ type sessionState struct {
|
||||
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
listener net.Listener
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
|
||||
@@ -151,7 +152,6 @@ type Server struct {
|
||||
// Populated at authentication time, stores JWT username and port forwards for status display.
|
||||
connections map[connKey]*connState
|
||||
|
||||
|
||||
allowLocalPortForwarding bool
|
||||
allowRemotePortForwarding bool
|
||||
allowRootLogin bool
|
||||
@@ -240,6 +240,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
s.listener = ln
|
||||
s.sshServer = sshServer
|
||||
log.Infof("SSH server started on %s", addrDesc)
|
||||
|
||||
@@ -292,6 +293,7 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
@@ -307,6 +309,18 @@ func (s *Server) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the address the SSH server is listening on, or nil if the server is not running
|
||||
func (s *Server) Addr() net.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of the SSH server and active sessions.
|
||||
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
s.mu.RLock()
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -14,23 +13,21 @@ func StartTestServer(t *testing.T, server *Server) string {
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||
// Use port 0 to let the OS assign a free port
|
||||
addrPort := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
|
||||
// Get the actual listening address from the server
|
||||
actualAddr := server.Addr()
|
||||
if actualAddr == nil {
|
||||
errChan <- fmt.Errorf("server started but no listener address available")
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr.String()
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user