mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Use net.JoinHostPort and net.SplitHostPort for IPv6-safe host:port handling (#5836)
This commit is contained in:
@@ -157,7 +157,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
|
|||||||
if u.Opaque != "" {
|
if u.Opaque != "" {
|
||||||
host, port, err := net.SplitHostPort(u.Opaque)
|
host, port, err := net.SplitHostPort(u.Opaque)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
|
||||||
} else {
|
} else {
|
||||||
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
anonymizedHost = a.AnonymizeDomain(u.Opaque)
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
|
|||||||
} else if u.Host != "" {
|
} else if u.Host != "" {
|
||||||
host, port, err := net.SplitHostPort(u.Host)
|
host, port, err := net.SplitHostPort(u.Host)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
|
anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
|
||||||
} else {
|
} else {
|
||||||
anonymizedHost = a.AnonymizeDomain(u.Host)
|
anonymizedHost = a.AnonymizeDomain(u.Host)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -286,6 +286,16 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
|
|||||||
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
|
||||||
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "STUN URI with IPv6",
|
||||||
|
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
|
||||||
|
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTPS URI with IPv6",
|
||||||
|
input: "Visit https://[2001:db8::ff00:42]:443/path",
|
||||||
|
expect: "Visit https://[2001:db8:ffff::]:443/path",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
|||||||
@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||||
target := fmt.Sprintf("%s:%d", addr, port)
|
target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
|
||||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||||
KnownHostsFile: knownHostsFile,
|
KnownHostsFile: knownHostsFile,
|
||||||
IdentityFile: identityFile,
|
IdentityFile: identityFile,
|
||||||
|
|||||||
@@ -13,6 +13,54 @@ import (
|
|||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
|
func TestConnKey_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key ConnKey
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4",
|
||||||
|
key: ConnKey{
|
||||||
|
SrcIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
DstIP: netip.MustParseAddr("10.0.0.1"),
|
||||||
|
SrcPort: 12345,
|
||||||
|
DstPort: 80,
|
||||||
|
},
|
||||||
|
expect: "192.168.1.1:12345 → 10.0.0.1:80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6",
|
||||||
|
key: ConnKey{
|
||||||
|
SrcIP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
DstIP: netip.MustParseAddr("2001:db8::2"),
|
||||||
|
SrcPort: 54321,
|
||||||
|
DstPort: 443,
|
||||||
|
},
|
||||||
|
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4-mapped IPv6 unmaps",
|
||||||
|
key: ConnKey{
|
||||||
|
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
|
||||||
|
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
|
||||||
|
SrcPort: 1000,
|
||||||
|
DstPort: 2000,
|
||||||
|
},
|
||||||
|
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := tc.key.String()
|
||||||
|
if got != tc.expect {
|
||||||
|
t.Errorf("got %q, want %q", got, tc.expect)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -137,12 +138,12 @@ func (info ICMPInfo) parseOriginalPacket() string {
|
|||||||
case nftypes.TCP:
|
case nftypes.TCP:
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
|
||||||
|
|
||||||
case nftypes.UDP:
|
case nftypes.UDP:
|
||||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
|
||||||
|
|
||||||
case nftypes.ICMP:
|
case nftypes.ICMP:
|
||||||
icmpType := transportData[0]
|
icmpType := transportData[0]
|
||||||
|
|||||||
@@ -5,6 +5,42 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestICMPConnKey_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key ICMPConnKey
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4",
|
||||||
|
key: ICMPConnKey{
|
||||||
|
SrcIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
DstIP: netip.MustParseAddr("10.0.0.1"),
|
||||||
|
ID: 1234,
|
||||||
|
},
|
||||||
|
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6",
|
||||||
|
key: ICMPConnKey{
|
||||||
|
SrcIP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
DstIP: netip.MustParseAddr("2001:db8::2"),
|
||||||
|
ID: 5678,
|
||||||
|
},
|
||||||
|
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := tc.key.String()
|
||||||
|
if got != tc.expect {
|
||||||
|
t.Errorf("got %q, want %q", got, tc.expect)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -443,7 +445,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
|
|||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder.Load() != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
@@ -759,8 +760,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
|
newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s", config.ManagementURL.Scheme, net.JoinHostPort(defaultManagementURL.Hostname(), "443")))
|
||||||
config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -257,7 +258,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port)
|
turnServerAddr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))
|
||||||
|
|
||||||
var conn net.PacketConn
|
var conn net.PacketConn
|
||||||
switch uri.Proto {
|
switch uri.Proto {
|
||||||
|
|||||||
@@ -259,6 +259,9 @@ func findRandomAvailableUDPPort() (int, error) {
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
splitAddress := strings.Split(conn.LocalAddr().String(), ":")
|
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String())
|
||||||
return strconv.Atoi(splitAddress[len(splitAddress)-1])
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parse local address %s: %w", conn.LocalAddr(), err)
|
||||||
|
}
|
||||||
|
return strconv.Atoi(portStr)
|
||||||
}
|
}
|
||||||
|
|||||||
14
client/internal/rosenpass/manager_test.go
Normal file
14
client/internal/rosenpass/manager_test.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package rosenpass
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||||
|
port, err := findRandomAvailableUDPPort()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, port, 0)
|
||||||
|
require.LessOrEqual(t, port, 65535)
|
||||||
|
}
|
||||||
@@ -321,7 +321,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
|
dest := net.JoinHostPort(payload.DestAddr, strconv.Itoa(int(payload.DestPort)))
|
||||||
log.Debugf("local port forwarding: %s", dest)
|
log.Debugf("local port forwarding: %s", dest)
|
||||||
|
|
||||||
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
|
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
|
||||||
|
|||||||
@@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
|||||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||||
logger := s.getRequestLogger(ctx)
|
logger := s.getRequestLogger(ctx)
|
||||||
if !allowLocal {
|
if !allowLocal {
|
||||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
|
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||||
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
|
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
|||||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||||
logger := s.getRequestLogger(ctx)
|
logger := s.getRequestLogger(ctx)
|
||||||
if !allowRemote {
|
if !allowRemote {
|
||||||
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
|
logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||||
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
|
logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
|
||||||
|
key := forwardKey(hostPort)
|
||||||
if s.removeRemoteForwardListener(key) {
|
if s.removeRemoteForwardListener(key) {
|
||||||
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
|
forwardAddr := "-R " + hostPort
|
||||||
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
|
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
|
||||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
logger.Infof("remote port forwarding cancelled: %s", hostPort)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
|
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := ln.Close(); err != nil {
|
if err := ln.Close(); err != nil {
|
||||||
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
|
logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
|||||||
}
|
}
|
||||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
|
logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port))))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
|
|||||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||||
s.storeRemoteForwardListener(key, ln)
|
s.storeRemoteForwardListener(key, ln)
|
||||||
|
|
||||||
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
|
forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))
|
||||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||||
|
|
||||||
response := make([]byte, 4)
|
response := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(response, actualPort)
|
binary.BigEndian.PutUint32(response, actualPort)
|
||||||
|
|
||||||
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
|
logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
|
||||||
return true, response
|
return true, response
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
|||||||
|
|
||||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
|
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
|
logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -918,20 +919,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
|||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if !allowLocal {
|
if !allowLocal {
|
||||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
|
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||||
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
|
||||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
|
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
|
||||||
|
forwardAddr := "-L " + hostPort
|
||||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||||
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
logger.Infof("local port forwarding: %s", hostPort)
|
||||||
|
|
||||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost,
|
|||||||
// Auto-configure local STUN servers for all ports
|
// Auto-configure local STUN servers for all ports
|
||||||
for _, port := range c.Server.StunPorts {
|
for _, port := range c.Server.StunPorts {
|
||||||
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
|
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
|
||||||
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
|
URI: "stun:" + net.JoinHostPort(strings.Trim(exposedHost, "[]"), fmt.Sprintf("%d", port)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -1102,7 +1103,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
|||||||
|
|
||||||
serviceURL := "https://" + svc.Domain
|
serviceURL := "https://" + svc.Domain
|
||||||
if service.IsL4Protocol(svc.Mode) {
|
if service.IsL4Protocol(svc.Mode) {
|
||||||
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
|
serviceURL = fmt.Sprintf("%s://%s", svc.Mode, net.JoinHostPort(svc.Domain, strconv.Itoa(int(svc.ListenPort))))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &service.ExposeServiceResponse{
|
return &service.ExposeServiceResponse{
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -201,7 +204,7 @@ func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.R
|
|||||||
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
|
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
|
||||||
var nsList []nbdns.NameServer
|
var nsList []nbdns.NameServer
|
||||||
for _, apiNS := range apiNSList {
|
for _, apiNS := range apiNSList {
|
||||||
parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s:%d", apiNS.NsType, apiNS.Ip, apiNS.Port))
|
parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s", apiNS.NsType, net.JoinHostPort(strings.Trim(apiNS.Ip, "[]"), strconv.Itoa(apiNS.Port))))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -233,3 +233,37 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToServerNSList_IPv6(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []api.Nameserver
|
||||||
|
expectIP netip.Addr
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4",
|
||||||
|
input: []api.Nameserver{
|
||||||
|
{Ip: "1.1.1.1", NsType: "udp", Port: 53},
|
||||||
|
},
|
||||||
|
expectIP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6",
|
||||||
|
input: []api.Nameserver{
|
||||||
|
{Ip: "2001:4860:4860::8888", NsType: "udp", Port: 53},
|
||||||
|
},
|
||||||
|
expectIP: netip.MustParseAddr("2001:4860:4860::8888"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result, err := toServerNSList(tc.input)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.Len(t, result, 1) {
|
||||||
|
assert.Equal(t, tc.expectIP, result[0].IP)
|
||||||
|
assert.Equal(t, 53, result[0].Port)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -337,7 +337,7 @@ func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration {
|
|||||||
func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
|
func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
// Dial TURN Server
|
// Dial TURN Server
|
||||||
addrStr := fmt.Sprintf("%s:%d", address, 443)
|
addrStr := net.JoinHostPort(address, "443")
|
||||||
|
|
||||||
fac := logging.NewDefaultLoggerFactory()
|
fac := logging.NewDefaultLoggerFactory()
|
||||||
//fac.DefaultLogLevel = logging.LogLevelTrace
|
//fac.DefaultLogLevel = logging.LogLevelTrace
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func AllocateTurnClient(serverAddr string) *TurnConn {
|
|||||||
|
|
||||||
func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
|
func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
|
||||||
// Dial TURN Server
|
// Dial TURN Server
|
||||||
addrStr := fmt.Sprintf("%s:%d", address, 443)
|
addrStr := net.JoinHostPort(address, "443")
|
||||||
|
|
||||||
fac := logging.NewDefaultLoggerFactory()
|
fac := logging.NewDefaultLoggerFactory()
|
||||||
//fac.DefaultLogLevel = logging.LogLevelTrace
|
//fac.DefaultLogLevel = logging.LogLevelTrace
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -52,7 +53,7 @@ func Test_S3HandlerGetUploadURL(t *testing.T) {
|
|||||||
hostIP, err := c.Host(ctx)
|
hostIP, err := c.Host(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port()
|
awsEndpoint := "http://" + net.JoinHostPort(hostIP, mappedPort.Port())
|
||||||
|
|
||||||
t.Setenv("AWS_REGION", awsRegion)
|
t.Setenv("AWS_REGION", awsRegion)
|
||||||
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)
|
t.Setenv("AWS_ENDPOINT_URL", awsEndpoint)
|
||||||
|
|||||||
Reference in New Issue
Block a user