mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
Use net.JoinHostPort and net.SplitHostPort for IPv6-safe host:port handling (#5836)
This commit is contained in:
@@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
if !allowLocal {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
|
||||
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))))
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
|
||||
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
if !allowRemote {
|
||||
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
|
||||
logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))))
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
|
||||
logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
|
||||
key := forwardKey(hostPort)
|
||||
if s.removeRemoteForwardListener(key) {
|
||||
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
|
||||
forwardAddr := "-R " + hostPort
|
||||
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
||||
logger.Infof("remote port forwarding cancelled: %s", hostPort)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
||||
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
|
||||
logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
||||
}
|
||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
|
||||
logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port))))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
|
||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||
}
|
||||
|
||||
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||
s.storeRemoteForwardListener(key, ln)
|
||||
|
||||
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
|
||||
forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||
|
||||
response := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(response, actualPort)
|
||||
|
||||
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
|
||||
logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
|
||||
return true, response
|
||||
}
|
||||
|
||||
@@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
|
||||
logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -918,20 +919,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowLocal {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
|
||||
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||
return
|
||||
}
|
||||
|
||||
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
|
||||
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
|
||||
forwardAddr := "-L " + hostPort
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
logger.Infof("local port forwarding: %s", hostPort)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user