Use net.JoinHostPort and net.SplitHostPort for IPv6-safe host:port handling (#5836)

This commit is contained in:
Viktor Liu
2026-04-10 09:10:57 +08:00
committed by GitHub
parent 0cc90e2a8a
commit f484835292
21 changed files with 193 additions and 36 deletions

View File

@@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))))
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err)
return false
}
@@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowRemote {
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))))
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err)
return false
}
@@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil
}
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
key := forwardKey(hostPort)
if s.removeRemoteForwardListener(key) {
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
forwardAddr := "-R " + hostPort
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
logger.Infof("remote port forwarding cancelled: %s", hostPort)
return true, nil
}
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
return false, nil
}
@@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
defer func() {
if err := ln.Close(); err != nil {
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
}
}()
@@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port))))
return
}
}
@@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
s.storeRemoteForwardListener(key, ln)
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
binary.BigEndian.PutUint32(response, actualPort)
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
return true, response
}
@@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil {
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
_ = conn.Close()
return
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"slices"
"strings"
@@ -918,20 +919,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
forwardAddr := "-L " + hostPort
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
logger.Infof("local port forwarding: %s", hostPort)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}