mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
[client] Add port forwarding to ssh proxy (#5031)
* Implement port forwarding for the ssh proxy * Allow user switching for port forwarding
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -40,6 +41,11 @@ const (
|
||||
|
||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||
|
||||
cmdInteractiveShell = "<interactive shell>"
|
||||
cmdPortForwarding = "<port forwarding>"
|
||||
cmdSFTP = "<sftp>"
|
||||
cmdNonInteractive = "<idle>"
|
||||
|
||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||
DefaultJWTMaxTokenAge = 5 * 60
|
||||
)
|
||||
@@ -90,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// safeLogCommand returns a safe representation of the command for logging
|
||||
// safeLogCommand returns a safe representation of the command for logging.
|
||||
func safeLogCommand(cmd []string) string {
|
||||
if len(cmd) == 0 {
|
||||
return "<interactive shell>"
|
||||
return cmdInteractiveShell
|
||||
}
|
||||
if len(cmd) == 1 {
|
||||
return cmd[0]
|
||||
@@ -101,26 +107,50 @@ func safeLogCommand(cmd []string) string {
|
||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||
}
|
||||
|
||||
type sshConnectionState struct {
|
||||
hasActivePortForward bool
|
||||
username string
|
||||
remoteAddr string
|
||||
// connState tracks the state of an SSH connection for port forwarding and status display.
|
||||
type connState struct {
|
||||
username string
|
||||
remoteAddr net.Addr
|
||||
portForwards []string
|
||||
jwtUsername string
|
||||
}
|
||||
|
||||
// authKey uniquely identifies an authentication attempt by username and remote address.
|
||||
// Used to temporarily store JWT username between passwordHandler and sessionHandler.
|
||||
type authKey string
|
||||
|
||||
// connKey uniquely identifies an SSH connection by its remote address.
|
||||
// Used to track authenticated connections for status display and port forwarding.
|
||||
type connKey string
|
||||
|
||||
func newAuthKey(username string, remoteAddr net.Addr) authKey {
|
||||
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
|
||||
}
|
||||
|
||||
// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP).
|
||||
type sessionState struct {
|
||||
session ssh.Session
|
||||
sessionType string
|
||||
jwtUsername string
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sessions map[SessionKey]ssh.Session
|
||||
sessionCancels map[ConnectionKey]context.CancelFunc
|
||||
sessionJWTUsers map[SessionKey]string
|
||||
pendingAuthJWT map[authKey]string
|
||||
sshServer *ssh.Server
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
|
||||
// sessions tracks active SSH sessions (shell, command, SFTP).
|
||||
// These are created when a client opens a session channel and requests shell/exec/subsystem.
|
||||
sessions map[sessionKey]*sessionState
|
||||
|
||||
// pendingAuthJWT temporarily stores JWT username during the auth→session handoff.
|
||||
// Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler.
|
||||
pendingAuthJWT map[authKey]string
|
||||
|
||||
// connections tracks all SSH connections by their remote address.
|
||||
// Populated at authentication time, stores JWT username and port forwards for status display.
|
||||
connections map[connKey]*connState
|
||||
|
||||
|
||||
allowLocalPortForwarding bool
|
||||
allowRemotePortForwarding bool
|
||||
@@ -132,8 +162,7 @@ type Server struct {
|
||||
|
||||
wgAddress wgaddr.Address
|
||||
|
||||
remoteForwardListeners map[ForwardKey]net.Listener
|
||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||
remoteForwardListeners map[forwardKey]net.Listener
|
||||
|
||||
jwtValidator *jwt.Validator
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
@@ -167,6 +196,7 @@ type SessionInfo struct {
|
||||
RemoteAddress string
|
||||
Command string
|
||||
JWTUsername string
|
||||
PortForwards []string
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
@@ -175,11 +205,10 @@ func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
sessions: make(map[SessionKey]ssh.Session),
|
||||
sessionJWTUsers: make(map[SessionKey]string),
|
||||
sessions: make(map[sessionKey]*sessionState),
|
||||
pendingAuthJWT: make(map[authKey]string),
|
||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
remoteForwardListeners: make(map[forwardKey]net.Listener),
|
||||
connections: make(map[connKey]*connState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
||||
@@ -265,14 +294,8 @@ func (s *Server) Stop() error {
|
||||
s.sshServer = nil
|
||||
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.sessionJWTUsers)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.sshConnections)
|
||||
|
||||
for _, cancelFunc := range s.sessionCancels {
|
||||
cancelFunc()
|
||||
}
|
||||
maps.Clear(s.sessionCancels)
|
||||
maps.Clear(s.connections)
|
||||
|
||||
for _, listener := range s.remoteForwardListeners {
|
||||
if err := listener.Close(); err != nil {
|
||||
@@ -284,32 +307,70 @@ func (s *Server) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of the SSH server and active sessions
|
||||
// GetStatus returns the current status of the SSH server and active sessions.
|
||||
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
enabled = s.sshServer != nil
|
||||
reportedAddrs := make(map[string]bool)
|
||||
|
||||
for sessionKey, session := range s.sessions {
|
||||
cmd := "<interactive shell>"
|
||||
if len(session.Command()) > 0 {
|
||||
cmd = safeLogCommand(session.Command())
|
||||
for _, state := range s.sessions {
|
||||
info := s.buildSessionInfo(state)
|
||||
reportedAddrs[info.RemoteAddress] = true
|
||||
sessions = append(sessions, info)
|
||||
}
|
||||
|
||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
||||
for key, connState := range s.connections {
|
||||
remoteAddr := string(key)
|
||||
if reportedAddrs[remoteAddr] {
|
||||
continue
|
||||
}
|
||||
cmd := cmdNonInteractive
|
||||
if len(connState.portForwards) > 0 {
|
||||
cmd = cmdPortForwarding
|
||||
}
|
||||
|
||||
jwtUsername := s.sessionJWTUsers[sessionKey]
|
||||
|
||||
sessions = append(sessions, SessionInfo{
|
||||
Username: session.User(),
|
||||
RemoteAddress: session.RemoteAddr().String(),
|
||||
Username: connState.username,
|
||||
RemoteAddress: remoteAddr,
|
||||
Command: cmd,
|
||||
JWTUsername: jwtUsername,
|
||||
JWTUsername: connState.jwtUsername,
|
||||
PortForwards: connState.portForwards,
|
||||
})
|
||||
}
|
||||
|
||||
return enabled, sessions
|
||||
}
|
||||
|
||||
func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
|
||||
session := state.session
|
||||
cmd := state.sessionType
|
||||
if cmd == "" {
|
||||
cmd = safeLogCommand(session.Command())
|
||||
}
|
||||
|
||||
remoteAddr := session.RemoteAddr().String()
|
||||
info := SessionInfo{
|
||||
Username: session.User(),
|
||||
RemoteAddress: remoteAddr,
|
||||
Command: cmd,
|
||||
JWTUsername: state.jwtUsername,
|
||||
}
|
||||
|
||||
connState, exists := s.connections[connKey(remoteAddr)]
|
||||
if !exists {
|
||||
return info
|
||||
}
|
||||
|
||||
info.PortForwards = connState.portForwards
|
||||
if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) {
|
||||
info.Command = cmdPortForwarding
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
@@ -520,69 +581,129 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
osUsername := ctx.User()
|
||||
remoteAddr := ctx.RemoteAddr()
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
logger.Errorf("JWT validator initialization failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
logger.Warnf("JWT authentication failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||
logger.Warnf("user validation failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
logger = logger.WithField("jwt_user", userAuth.UserId)
|
||||
|
||||
s.mu.RLock()
|
||||
authorizer := s.authorizer
|
||||
s.mu.RUnlock()
|
||||
|
||||
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
|
||||
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
|
||||
msg, err := authorizer.Authorize(userAuth.UserId, osUsername)
|
||||
if err != nil {
|
||||
logger.Warnf("SSH auth denied: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Infof("SSH auth %s", msg)
|
||||
|
||||
key := newAuthKey(osUsername, remoteAddr)
|
||||
remoteAddrStr := ctx.RemoteAddr().String()
|
||||
s.mu.Lock()
|
||||
s.pendingAuthJWT[key] = userAuth.UserId
|
||||
s.connections[connKey(remoteAddrStr)] = &connState{
|
||||
username: ctx.User(),
|
||||
remoteAddr: ctx.RemoteAddr(),
|
||||
jwtUsername: userAuth.UserId,
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||
func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if state, exists := s.sshConnections[sshConn]; exists {
|
||||
state.hasActivePortForward = true
|
||||
} else {
|
||||
s.sshConnections[sshConn] = &sshConnectionState{
|
||||
hasActivePortForward: true,
|
||||
username: username,
|
||||
remoteAddr: remoteAddr,
|
||||
key := connKey(remoteAddr.String())
|
||||
if state, exists := s.connections[key]; exists {
|
||||
if !slices.Contains(state.portForwards, forwardAddr) {
|
||||
state.portForwards = append(state.portForwards, forwardAddr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Connection not in connections (non-JWT auth path)
|
||||
s.connections[key] = &connState{
|
||||
username: username,
|
||||
remoteAddr: remoteAddr,
|
||||
portForwards: []string{forwardAddr},
|
||||
jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)],
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||
// We can't extract the SSH connection from net.Conn directly
|
||||
// Connection cleanup will happen during session cleanup or via timeout
|
||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||
func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
state, exists := s.connections[connKey(remoteAddr.String())]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool {
|
||||
return addr == forwardAddr
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
// trackedConn wraps a net.Conn to detect when it closes
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
server *Server
|
||||
remoteAddr string
|
||||
onceClose sync.Once
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
c.onceClose.Do(func() {
|
||||
c.server.handleConnectionClose(c.remoteAddr)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handleConnectionClose(remoteAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := connKey(remoteAddr)
|
||||
state, exists := s.connections[key]
|
||||
if exists && len(state.portForwards) > 0 {
|
||||
s.connLogger(state).Info("port forwarding connection closed")
|
||||
}
|
||||
delete(s.connections, key)
|
||||
}
|
||||
|
||||
func (s *Server) connLogger(state *connState) *log.Entry {
|
||||
logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr))
|
||||
if state.jwtUsername != "" {
|
||||
logger = logger.WithField("jwt_user", state.jwtUsername)
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
|
||||
if ctx == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try to match by SSH connection
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConn == nil {
|
||||
return "unknown"
|
||||
@@ -591,19 +712,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Look through sessions to find one with matching connection
|
||||
for sessionKey, session := range s.sessions {
|
||||
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
||||
for sessionKey, state := range s.sessions {
|
||||
if state.session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
||||
return sessionKey
|
||||
}
|
||||
}
|
||||
|
||||
// If no session found, this might be during early connection setup
|
||||
// Return a temporary key that we'll fix up later
|
||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||
return tempKey
|
||||
return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
@@ -644,7 +760,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
}
|
||||
|
||||
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
|
||||
return conn
|
||||
return &trackedConn{
|
||||
Conn: conn,
|
||||
server: s,
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
@@ -672,9 +792,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
"tcpip-forward": s.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
|
||||
},
|
||||
ConnCallback: s.connectionValidator,
|
||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||
Version: serverVersion,
|
||||
ConnCallback: s.connectionValidator,
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
if s.jwtEnabled {
|
||||
@@ -690,13 +809,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||
func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.remoteForwardListeners[key] = ln
|
||||
}
|
||||
|
||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
func (s *Server) removeRemoteForwardListener(key forwardKey) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -714,6 +833,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
}
|
||||
|
||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
var payload struct {
|
||||
Host string
|
||||
Port uint32
|
||||
@@ -723,7 +844,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
|
||||
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
|
||||
log.Debugf("channel reject error: %v", err)
|
||||
logger.Debugf("channel reject error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -733,19 +854,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
|
||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Check privilege requirements for the destination port
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user