Files
netbird/client/ssh/server/port_forwarding.go
Viktor Liu f012fb8592 [client] Add port forwarding to ssh proxy (#5031)
* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
2026-01-07 12:18:04 +08:00

391 lines
12 KiB
Go

// Package server implements port forwarding for the SSH server.
//
// Security note: Port forwarding runs in the main server process without privilege separation.
// The attack surface is primarily io.Copy through well-tested standard library code, making it
// lower risk than shell execution which uses privilege-separated child processes. We enforce
// user-level port restrictions: non-privileged users cannot bind to ports < 1024.
package server
import (
"encoding/binary"
"fmt"
"net"
"runtime"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
// privilegedPortThreshold is the first port number that is NOT treated as
// privileged. checkPrivilegedPortAccess only restricts bind requests whose port
// is non-zero and below this value.
const privilegedPortThreshold = 1024
// sessionKey uniquely identifies an SSH session.
// NOTE(review): presumably the key type of Server.sessions (getRequestLogger
// indexes that map with the value returned by findSessionKeyByContext) —
// confirm against the map declaration, which is not in this file.
type sessionKey string
// forwardKey uniquely identifies a port forwarding listener.
// In this file keys are built as "<host>:<port>" with fmt.Sprintf("%s:%d", ...)
// by setupDirectForward (when storing a listener) and by
// cancelTcpipForwardHandler (when removing one), so both must format the key
// the same way for a cancel request to find its listener.
type forwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests.
// It is decoded with cryptossh.Unmarshal from the payload of both
// tcpip-forward and cancel-tcpip-forward requests, so the field order below is
// the wire order and must not be changed.
type tcpipForwardMsg struct {
// Host is the address the client asks the server to bind.
Host string
// Port is the port to bind; setupDirectForward treats 0 as "let the
// listener pick a port" and reports the allocated port back to the client.
Port uint32
}
// SetAllowLocalPortForwarding enables (allow == true) or disables local port
// forwarding. The flag is written while holding the server's write lock.
func (s *Server) SetAllowLocalPortForwarding(allow bool) {
	s.mu.Lock()
	s.allowLocalPortForwarding = allow
	s.mu.Unlock()
}
// SetAllowRemotePortForwarding enables (allow == true) or disables remote port
// forwarding. The flag is written while holding the server's write lock.
func (s *Server) SetAllowRemotePortForwarding(allow bool) {
	s.mu.Lock()
	s.allowRemotePortForwarding = allow
	s.mu.Unlock()
}
// configurePortForwarding installs the local and reverse port forwarding
// callbacks on server.
//
// The two allow flags are read once, here, and captured by the callbacks; the
// callbacks therefore keep using the values that were current when this method
// ran. The flags are read without taking s.mu in this function.
// NOTE(review): presumably the caller already holds the lock — confirm.
//
// Each callback denies the request (returns false and logs a warning) when the
// corresponding direction is disabled or when checkPortForwardingPrivileges
// rejects the user/port combination.
func (s *Server) configurePortForwarding(server *ssh.Server) {
	localEnabled, remoteEnabled := s.allowLocalPortForwarding, s.allowRemotePortForwarding

	server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
		reqLog := s.getRequestLogger(ctx)
		if !localEnabled {
			reqLog.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
			return false
		}

		err := s.checkPortForwardingPrivileges(ctx, "local", dstPort)
		if err == nil {
			return true
		}
		reqLog.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
		return false
	}

	server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
		reqLog := s.getRequestLogger(ctx)
		if !remoteEnabled {
			reqLog.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
			return false
		}

		err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort)
		if err == nil {
			return true
		}
		reqLog.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
		return false
	}

	log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", localEnabled, remoteEnabled)
}
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to
// ports below 1024, mirroring the restriction they would face if binding directly.
//
// forwardType is used both in messages and to decide whether the request is a
// bind operation (see checkPrivilegedPortAccess). It returns nil when the
// request is permitted and a non-nil error describing the denial otherwise. A
// nil ctx is always denied.
//
// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user,
// though we don't actually switch users - port forwarding runs in the server process. The resolved
// user is used for privileged port access checks.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
	if ctx == nil {
		return fmt.Errorf("%s port forwarding denied: no context", forwardType)
	}

	result := s.CheckPrivileges(PrivilegeCheckRequest{
		RequestedUsername:         ctx.User(),
		FeatureSupportsUserSwitch: true,
		FeatureName:               forwardType + " port forwarding",
	})
	if !result.Allowed {
		if result.Error != nil {
			return result.Error
		}
		// Fail closed: a denial without an attached error must not be
		// reported to the caller as success (nil error == allowed).
		return fmt.Errorf("%s port forwarding denied for user %s", forwardType, ctx.User())
	}

	return s.checkPrivilegedPortAccess(forwardType, port, result)
}
// checkPrivilegedPortAccess rejects attempts by non-privileged users to bind a
// privileged port (a non-zero port below privilegedPortThreshold).
//
// The check only applies to bind operations, i.e. when forwardType is "remote"
// or "tcpip-forward"; every other forwardType is accepted. It is skipped
// entirely on Windows. A request is accepted when result.User is set and
// isPrivilegedUsername reports its Username as privileged; otherwise an error
// naming the user ("unknown" when result.User is nil) and the port is returned.
func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error {
	switch {
	case runtime.GOOS == "windows":
		return nil
	case forwardType != "remote" && forwardType != "tcpip-forward":
		// Not a bind operation.
		return nil
	case port == 0, port >= privilegedPortThreshold:
		// Port 0 means "pick any available port", which will be >= 1024.
		return nil
	}

	name := "unknown"
	if u := result.User; u != nil {
		if isPrivilegedUsername(u.Username) {
			return nil
		}
		name = u.Username
	}

	return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", name, port)
}
// tcpipForwardHandler serves a tcpip-forward request, i.e. a client asking the
// server to open a listener for remote port forwarding.
//
// The request is refused (false, nil) when remote forwarding is disabled, the
// payload cannot be decoded, the privilege check fails, or no SSH connection
// can be taken from ctx. Otherwise the result of setupDirectForward is
// returned unchanged.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
	reqLog := s.getRequestLogger(ctx)

	if enabled := s.isRemotePortForwardingAllowed(); !enabled {
		reqLog.Warnf("tcpip-forward request denied: remote port forwarding disabled")
		return false, nil
	}

	msg, parseErr := s.parseTcpipForwardRequest(req)
	if parseErr != nil {
		reqLog.Errorf("tcpip-forward unmarshal error: %v", parseErr)
		return false, nil
	}

	privErr := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", msg.Port)
	if privErr != nil {
		reqLog.Warnf("tcpip-forward denied: %v", privErr)
		return false, nil
	}

	serverConn, connErr := s.getSSHConnection(ctx)
	if connErr != nil {
		reqLog.Warnf("tcpip-forward request denied: %v", connErr)
		return false, nil
	}

	return s.setupDirectForward(ctx, reqLog, serverConn, msg)
}
// cancelTcpipForwardHandler serves a cancel-tcpip-forward request.
//
// The payload is decoded into a tcpipForwardMsg and the listener registered
// under "<host>:<port>" is removed. On success the matching "-R host:port"
// entry is also removed from the connection's port-forward bookkeeping and
// (true, nil) is returned. A malformed payload or an unknown listener yields
// (false, nil).
func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
	reqLog := s.getRequestLogger(ctx)

	var msg tcpipForwardMsg
	err := cryptossh.Unmarshal(req.Payload, &msg)
	if err != nil {
		reqLog.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
		return false, nil
	}
	host, port := msg.Host, msg.Port

	listenerKey := forwardKey(fmt.Sprintf("%s:%d", host, port))
	if removed := s.removeRemoteForwardListener(listenerKey); !removed {
		reqLog.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", host, port)
		return false, nil
	}

	s.removeConnectionPortForward(ctx.RemoteAddr(), fmt.Sprintf("-R %s:%d", host, port))
	reqLog.Infof("remote port forwarding cancelled: %s:%d", host, port)
	return true, nil
}
// handleRemoteForwardListener accepts incoming connections on ln for a remote
// port forward and hands each one to handleRemoteForwardConnection in its own
// goroutine.
//
// It blocks until ctx is done or Accept returns an error (for example because
// the listener was closed elsewhere), and always closes ln before returning.
// host and port are used for the forwarded connections and for log messages.
//
// Accepting runs in a helper goroutine so the dispatch loop can also watch
// ctx. The hand-off channel is unbuffered: every accepted connection is either
// received by the dispatch loop or, when ctx is cancelled first, closed by the
// helper. This keeps a connection accepted during shutdown from being leaked.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
	logger := s.getRequestLogger(ctx)

	defer func() {
		if err := ln.Close(); err != nil {
			logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
		}
	}()

	acceptChan := make(chan acceptResult)

	go func() {
		for {
			conn, err := ln.Accept()
			select {
			case acceptChan <- acceptResult{conn: conn, err: err}:
				if err != nil {
					return
				}
			case <-ctx.Done():
				// Nobody will pick this connection up anymore; release it.
				if conn != nil {
					_ = conn.Close()
				}
				return
			}
		}
	}()

	for {
		select {
		case result := <-acceptChan:
			if result.err != nil {
				logger.Debugf("remote forward accept error: %v", result.err)
				return
			}
			go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
		case <-ctx.Done():
			logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
			return
		}
	}
}
// getRequestLogger returns a log entry carrying the best available identity
// for the request behind ctx. In order of preference:
//
//  1. the session registered for ctx: field "session" set to the session key,
//     plus "jwt_user" when the session has a JWT username;
//  2. the connection registered for ctx's remote address: whatever connLogger
//     produces for it;
//  3. a fallback "session" field of the form "<user>@<remote address>", with
//     "unknown" in place of a missing remote address.
//
// The session key is resolved before s.mu is read-locked; the maps are read
// under the read lock.
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
	key := s.findSessionKeyByContext(ctx)

	s.mu.RLock()
	defer s.mu.RUnlock()

	if state, found := s.sessions[key]; found {
		entry := log.WithField("session", key)
		if state.jwtUsername == "" {
			return entry
		}
		return entry.WithField("jwt_user", state.jwtUsername)
	}

	remote := "unknown"
	if addr := ctx.RemoteAddr(); addr != nil {
		if connState, found := s.connections[connKey(addr.String())]; found {
			return s.connLogger(connState)
		}
		remote = addr.String()
	}

	return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remote))
}
// isRemotePortForwardingAllowed reports whether remote port forwarding is
// currently enabled. The flag is read under the server's read lock.
func (s *Server) isRemotePortForwardingAllowed() bool {
	s.mu.RLock()
	allowed := s.allowRemotePortForwarding
	s.mu.RUnlock()
	return allowed
}
// isPortForwardingEnabled reports whether at least one forwarding direction
// (local or remote) is currently enabled. Both flags are read under the
// server's read lock.
func (s *Server) isPortForwardingEnabled() bool {
	s.mu.RLock()
	local, remote := s.allowLocalPortForwarding, s.allowRemotePortForwarding
	s.mu.RUnlock()
	return local || remote
}
// parseTcpipForwardRequest decodes the payload of a tcpip-forward request into
// a tcpipForwardMsg.
//
// On success it returns the decoded message and a nil error. When the payload
// cannot be decoded it returns a nil message together with the decoding error,
// so a partially filled message is never handed to the caller.
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
	var payload tcpipForwardMsg
	if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
		return nil, err
	}
	return &payload, nil
}
// getSSHConnection returns the *cryptossh.ServerConn stored in ctx under
// ssh.ContextKeyConn.
//
// It returns an error when ctx is nil, when no value is stored under the key,
// or when the stored value is not a non-nil *cryptossh.ServerConn.
func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
	if ctx == nil {
		return nil, fmt.Errorf("no context")
	}

	switch conn := ctx.Value(ssh.ContextKeyConn).(type) {
	case nil:
		return nil, fmt.Errorf("no SSH connection in context")
	case *cryptossh.ServerConn:
		// A typed nil pointer is stored as a non-nil interface value, so it
		// lands here and must still be rejected.
		if conn == nil {
			return nil, fmt.Errorf("invalid SSH connection in context")
		}
		return conn, nil
	default:
		return nil, fmt.Errorf("invalid SSH connection in context")
	}
}
// setupDirectForward opens the TCP listener requested by a tcpip-forward
// message and starts serving it.
//
// It listens on payload.Host:payload.Port; a requested port of 0 lets the
// listener choose a free port. The listener is registered under
// "<host>:<actual port>" and recorded as "-R <host>:<actual port>" for the
// connection, then served by handleRemoteForwardListener in a new goroutine.
//
// The listener is keyed by the port that was actually bound, not the requested
// one. Otherwise every port-0 forward would share the key "<host>:0" and
// overwrite each other, and a cancel-tcpip-forward carrying the allocated port
// (the value reported back to the client) could never find its listener.
//
// Returns (true, reply) on success, where reply is the bound port as a 4-byte
// big-endian integer, and (false, nil) when the listener cannot be created.
// sshConn is accepted for signature compatibility and is not used here.
func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
	bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))

	ln, err := net.Listen("tcp", bindAddr)
	if err != nil {
		logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
		return false, nil
	}

	actualPort := payload.Port
	if payload.Port == 0 {
		tcpAddr, ok := ln.Addr().(*net.TCPAddr)
		if !ok {
			// Without the bound port the forward cannot be reported to the
			// client; give the listener back instead of panicking.
			logger.Errorf("tcpip-forward listener on %s has unexpected address type %T", bindAddr, ln.Addr())
			if closeErr := ln.Close(); closeErr != nil {
				logger.Debugf("tcpip-forward listener close error for %s: %v", bindAddr, closeErr)
			}
			return false, nil
		}
		actualPort = uint32(tcpAddr.Port)
		logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
	}

	key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, actualPort))
	s.storeRemoteForwardListener(key, ln)

	forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
	s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)

	go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)

	response := make([]byte, 4)
	binary.BigEndian.PutUint32(response, actualPort)

	logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
	return true, response
}
// acceptResult holds the result of a listener Accept() call.
// handleRemoteForwardListener uses it to pass each Accept outcome from its
// accepting goroutine to its dispatch loop over a channel.
type acceptResult struct {
// conn is the connection returned by Accept.
conn net.Conn
// err is the error returned by Accept; the dispatch loop stops on the first
// non-nil value.
err error
}
// handleRemoteForwardConnection relays one connection accepted on a remote
// forward listener to the SSH client.
//
// It takes the SSH server connection from ctx, opens a forwarded-tcpip channel
// for host:port that names conn's remote address as originator, and then
// copies data between conn and the channel via
// nbssh.BidirectionalCopyWithContext. If the SSH connection is missing, conn's
// remote address is not a TCP address, or the channel cannot be opened, conn
// is closed and the function returns.
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
	logger := s.getRequestLogger(ctx)

	// discard releases the inbound connection on the early-exit paths.
	discard := func() { _ = conn.Close() }

	serverConn, isServerConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
	if !isServerConn || serverConn == nil {
		logger.Debugf("remote forward: no SSH connection in context")
		discard()
		return
	}

	originator, isTCP := conn.RemoteAddr().(*net.TCPAddr)
	if !isTCP {
		logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
		discard()
		return
	}

	channel, err := s.openForwardChannel(serverConn, host, port, originator)
	if err != nil {
		logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
		discard()
		return
	}

	nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
}
// openForwardChannel opens a "forwarded-tcpip" channel on sshConn.
//
// The channel-open payload carries, in this order, the connected address and
// port (host, port) followed by the originator address and port taken from
// remoteAddr. Requests arriving on the new channel are discarded in a
// background goroutine. It returns the channel, or a wrapped error if the
// channel could not be opened.
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) {
	// The field order is the order in which cryptossh.Marshal encodes the
	// message, so it must stay as declared.
	type forwardedTCPIPMsg struct {
		ConnectedAddress  string
		ConnectedPort     uint32
		OriginatorAddress string
		OriginatorPort    uint32
	}

	msg := forwardedTCPIPMsg{
		ConnectedAddress:  host,
		ConnectedPort:     port,
		OriginatorAddress: remoteAddr.IP.String(),
		OriginatorPort:    uint32(remoteAddr.Port),
	}
	encoded := cryptossh.Marshal(&msg)

	ch, requests, err := sshConn.OpenChannel("forwarded-tcpip", encoded)
	if err != nil {
		return nil, fmt.Errorf("open SSH channel: %w", err)
	}

	go cryptossh.DiscardRequests(requests)
	return ch, nil
}