Use net.JoinHostPort and net.SplitHostPort for IPv6-safe host:port handling (#5836)

This commit is contained in:
Viktor Liu
2026-04-10 09:10:57 +08:00
committed by GitHub
parent 0cc90e2a8a
commit f484835292
21 changed files with 193 additions and 36 deletions

View File

@@ -157,7 +157,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
@@ -165,7 +165,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}

View File

@@ -286,6 +286,16 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
},
{
name: "STUN URI with IPv6",
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
},
{
name: "HTTPS URI with IPv6",
input: "Visit https://[2001:db8::ff00:42]:443/path",
expect: "Visit https://[2001:db8:ffff::]:443/path",
},
}
for _, tc := range tests {

View File

@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
}
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,

View File

@@ -13,6 +13,54 @@ import (
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestConnKey_String(t *testing.T) {
tests := []struct {
name string
key ConnKey
expect string
}{
{
name: "IPv4",
key: ConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
SrcPort: 12345,
DstPort: 80,
},
expect: "192.168.1.1:12345 → 10.0.0.1:80",
},
{
name: "IPv6",
key: ConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
SrcPort: 54321,
DstPort: 443,
},
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
},
{
name: "IPv4-mapped IPv6 unmaps",
key: ConnKey{
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
SrcPort: 1000,
DstPort: 2000,
},
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -137,12 +138,12 @@ func (info ICMPInfo) parseOriginalPacket() string {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
case nftypes.ICMP:
icmpType := transportData[0]

View File

@@ -5,6 +5,42 @@ import (
"testing"
)
func TestICMPConnKey_String(t *testing.T) {
tests := []struct {
name string
key ICMPConnKey
expect string
}{
{
name: "IPv4",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
ID: 1234,
},
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
},
{
name: "IPv6",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
ID: 5678,
},
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)

View File

@@ -2,7 +2,9 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"strconv"
"time"
"github.com/google/gopacket"
@@ -443,7 +445,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/url"
"os"
"os/user"
@@ -759,8 +760,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
return config, nil
}
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s", config.ManagementURL.Scheme, net.JoinHostPort(defaultManagementURL.Hostname(), "443")))
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"sync"
"time"
@@ -257,7 +258,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port)
turnServerAddr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))
var conn net.PacketConn
switch uri.Proto {

View File

@@ -259,6 +259,9 @@ func findRandomAvailableUDPPort() (int, error) {
}
defer conn.Close()
splitAddress := strings.Split(conn.LocalAddr().String(), ":")
return strconv.Atoi(splitAddress[len(splitAddress)-1])
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
return 0, fmt.Errorf("parse local address %s: %w", conn.LocalAddr(), err)
}
return strconv.Atoi(portStr)
}

View File

@@ -0,0 +1,14 @@
package rosenpass
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}

View File

@@ -321,7 +321,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
return
}
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
dest := net.JoinHostPort(payload.DestAddr, strconv.Itoa(int(payload.DestPort)))
log.Debugf("local port forwarding: %s", dest)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())

View File

@@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))))
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err)
return false
}
@@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowRemote {
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))))
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err)
return false
}
@@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil
}
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
key := forwardKey(hostPort)
if s.removeRemoteForwardListener(key) {
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
forwardAddr := "-R " + hostPort
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
logger.Infof("remote port forwarding cancelled: %s", hostPort)
return true, nil
}
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
return false, nil
}
@@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
defer func() {
if err := ln.Close(); err != nil {
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
}
}()
@@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port))))
return
}
}
@@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
s.storeRemoteForwardListener(key, ln)
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
binary.BigEndian.PutUint32(response, actualPort)
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
return true, response
}
@@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil {
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
_ = conn.Close()
return
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"slices"
"strings"
@@ -918,20 +919,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
forwardAddr := "-L " + hostPort
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
logger.Infof("local port forwarding: %s", hostPort)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}

View File

@@ -378,7 +378,7 @@ func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost,
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
URI: "stun:" + net.JoinHostPort(strings.Trim(exposedHost, "[]"), fmt.Sprintf("%d", port)),
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"math/rand/v2"
"net"
"net/http"
"os"
"slices"
@@ -1102,7 +1103,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
serviceURL := "https://" + svc.Domain
if service.IsL4Protocol(svc.Mode) {
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
serviceURL = fmt.Sprintf("%s://%s", svc.Mode, net.JoinHostPort(svc.Domain, strconv.Itoa(int(svc.ListenPort))))
}
return &service.ExposeServiceResponse{

View File

@@ -3,7 +3,10 @@ package dns
import (
"encoding/json"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
@@ -201,7 +204,7 @@ func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.R
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
var nsList []nbdns.NameServer
for _, apiNS := range apiNSList {
parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s:%d", apiNS.NsType, apiNS.Ip, apiNS.Port))
parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s", apiNS.NsType, net.JoinHostPort(strings.Trim(apiNS.Ip, "[]"), strconv.Itoa(apiNS.Port))))
if err != nil {
return nil, err
}

View File

@@ -233,3 +233,37 @@ func TestNameserversHandlers(t *testing.T) {
})
}
}
func TestToServerNSList_IPv6(t *testing.T) {
tests := []struct {
name string
input []api.Nameserver
expectIP netip.Addr
}{
{
name: "IPv4",
input: []api.Nameserver{
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
},
expectIP: netip.MustParseAddr("1.1.1.1"),
},
{
name: "IPv6",
input: []api.Nameserver{
{Ip: "2001:4860:4860::8888", NsType: "udp", Port: 53},
},
expectIP: netip.MustParseAddr("2001:4860:4860::8888"),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result, err := toServerNSList(tc.input)
assert.NoError(t, err)
if assert.Len(t, result, 1) {
assert.Equal(t, tc.expectIP, result[0].IP)
assert.Equal(t, 53, result[0].Port)
}
})
}
}

View File

@@ -337,7 +337,7 @@ func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration {
func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
t.Helper()
// Dial TURN Server
addrStr := fmt.Sprintf("%s:%d", address, 443)
addrStr := net.JoinHostPort(address, "443")
fac := logging.NewDefaultLoggerFactory()
//fac.DefaultLogLevel = logging.LogLevelTrace

View File

@@ -52,7 +52,7 @@ func AllocateTurnClient(serverAddr string) *TurnConn {
func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
// Dial TURN Server
addrStr := fmt.Sprintf("%s:%d", address, 443)
addrStr := net.JoinHostPort(address, "443")
fac := logging.NewDefaultLoggerFactory()
//fac.DefaultLogLevel = logging.LogLevelTrace

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"runtime"
@@ -52,7 +53,7 @@ func Test_S3HandlerGetUploadURL(t *testing.T) {
hostIP, err := c.Host(ctx)
require.NoError(t, err)
awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port()
awsEndpoint := "http://" + net.JoinHostPort(hostIP, mappedPort.Port())
t.Setenv("AWS_REGION", awsRegion)
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)