Files
netbird/proxy/internal/tcp/router_test.go

1742 lines
50 KiB
Go

package tcp
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"math/big"
"net"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
)
func TestRouter_HTTPRouting(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
router.AddRoute("example.com", Route{Type: RouteHTTP})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = router.Serve(ctx, ln)
}()
// Dial in a goroutine. The TLS handshake will block since nothing
// completes it on the HTTP side, but we only care about routing.
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
// Send a TLS ClientHello manually.
tlsConn := tls.Client(conn, &tls.Config{
ServerName: "example.com",
InsecureSkipVerify: true, //nolint:gosec
})
_ = tlsConn.Handshake()
tlsConn.Close()
}()
// Verify the connection was routed to the HTTP channel.
select {
case conn := <-router.httpCh:
assert.NotNil(t, conn)
conn.Close()
case <-time.After(5 * time.Second):
t.Fatal("no connection received on HTTP channel")
}
}
func TestRouter_TCPRouting(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
// Set up a TLS backend that the relay will connect to.
backendCert := generateSelfSignedCert(t)
backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{backendCert},
})
require.NoError(t, err)
defer backendLn.Close()
backendAddr := backendLn.Addr().String()
// Accept one connection on the backend, echo data back.
backendReady := make(chan struct{})
go func() {
close(backendReady)
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
}()
<-backendReady
dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
router := NewRouter(logger, dialResolve, addr)
router.AddRoute("tcp.example.com", Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "test-service",
Target: backendAddr,
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = router.Serve(ctx, ln)
}()
// Connect as a TLS client; the proxy should passthrough to the backend.
clientConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "tcp.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
defer clientConn.Close()
testData := []byte("hello through TCP passthrough")
_, err = clientConn.Write(testData)
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := clientConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n], "should receive echoed data through TCP passthrough")
}
func TestRouter_UnknownSNIGoesToHTTP(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
// No routes registered.
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = router.Serve(ctx, ln)
}()
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
tlsConn := tls.Client(conn, &tls.Config{
ServerName: "unknown.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
_ = tlsConn.Handshake()
tlsConn.Close()
}()
select {
case conn := <-router.httpCh:
assert.NotNil(t, conn)
conn.Close()
case <-time.After(5 * time.Second):
t.Fatal("unknown SNI should be routed to HTTP")
}
}
// TestRouter_NonTLSConnectionDropped verifies that a non-TLS connection
// on the shared port is closed by the router (SNI peek fails to find a
// valid ClientHello, so there is no route match).
func TestRouter_NonTLSConnectionDropped(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
// Register a TLS passthrough route. Non-TLS should NOT match.
dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
router := NewRouter(logger, dialResolve, addr)
router.AddRoute("tcp.example.com", Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "test-service",
Target: "127.0.0.1:9999",
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = router.Serve(ctx, ln)
}()
// Send plain HTTP (non-TLS) data.
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn.Close()
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: tcp.example.com\r\n\r\n"))
// Non-TLS traffic on a port with RouteTCP goes to the HTTP channel
// because there's no valid SNI to match. Verify it reaches HTTP.
select {
case httpConn := <-router.httpCh:
assert.NotNil(t, httpConn, "non-TLS connection should fall through to HTTP")
httpConn.Close()
case <-time.After(5 * time.Second):
t.Fatal("non-TLS connection was not routed to HTTP")
}
}
// TestRouter_TLSAndHTTPCoexist verifies that a shared port with both HTTP
// and TLS passthrough routes correctly demuxes based on the SNI hostname.
func TestRouter_TLSAndHTTPCoexist(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
backendCert := generateSelfSignedCert(t)
backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{backendCert},
})
require.NoError(t, err)
defer backendLn.Close()
// Backend echoes data.
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
}()
dialResolve := func(accountID types.AccountID) (types.DialContextFunc, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
router := NewRouter(logger, dialResolve, addr)
// HTTP route.
router.AddRoute("app.example.com", Route{Type: RouteHTTP})
// TLS passthrough route.
router.AddRoute("tcp.example.com", Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "test-service",
Target: backendLn.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = router.Serve(ctx, ln)
}()
// 1. TLS connection with SNI "tcp.example.com" → TLS passthrough.
tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "tcp.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
testData := []byte("passthrough data")
_, err = tlsConn.Write(testData)
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n], "TLS passthrough should relay data")
tlsConn.Close()
// 2. TLS connection with SNI "app.example.com" → HTTP handler.
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
c := tls.Client(conn, &tls.Config{
ServerName: "app.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
_ = c.Handshake()
c.Close()
}()
select {
case httpConn := <-router.httpCh:
assert.NotNil(t, httpConn, "HTTP SNI should go to HTTP handler")
httpConn.Close()
case <-time.After(5 * time.Second):
t.Fatal("HTTP-route connection was not delivered to HTTP handler")
}
}
func TestRouter_AddRemoveRoute(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"})
router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.1:5432"})
route, ok := router.lookupRoute("a.example.com")
assert.True(t, ok)
assert.Equal(t, RouteHTTP, route.Type)
route, ok = router.lookupRoute("b.example.com")
assert.True(t, ok)
assert.Equal(t, RouteTCP, route.Type)
router.RemoveRoute("a.example.com", "svc-a")
_, ok = router.lookupRoute("a.example.com")
assert.False(t, ok)
}
func TestChanListener_AcceptAndClose(t *testing.T) {
ch := make(chan net.Conn, 1)
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
ln := newChanListener(ch, addr)
assert.Equal(t, addr, ln.Addr())
// Send a connection.
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
defer serverConn.Close()
ch <- serverConn
conn, err := ln.Accept()
require.NoError(t, err)
assert.Equal(t, serverConn, conn)
// Close should cause Accept to return error.
require.NoError(t, ln.Close())
// Double close should be safe.
require.NoError(t, ln.Close())
_, err = ln.Accept()
assert.ErrorIs(t, err, net.ErrClosed)
}
func TestRouter_HTTPPrecedenceGuard(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
host := SNIHost("app.example.com")
t.Run("http takes precedence over tcp at lookup", func(t *testing.T) {
router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"})
router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"})
route, ok := router.lookupRoute(host)
require.True(t, ok)
assert.Equal(t, RouteHTTP, route.Type, "HTTP route must take precedence over TCP")
assert.Equal(t, types.ServiceID("svc-http"), route.ServiceID)
router.RemoveRoute(host, "svc-http")
router.RemoveRoute(host, "svc-tcp")
})
t.Run("tcp becomes active when http is removed", func(t *testing.T) {
router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"})
router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"})
router.RemoveRoute(host, "svc-http")
route, ok := router.lookupRoute(host)
require.True(t, ok)
assert.Equal(t, RouteTCP, route.Type, "TCP should take over after HTTP removal")
assert.Equal(t, types.ServiceID("svc-tcp"), route.ServiceID)
router.RemoveRoute(host, "svc-tcp")
})
t.Run("order of add does not matter", func(t *testing.T) {
router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-tcp", Target: "10.0.0.1:443"})
router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-http"})
route, ok := router.lookupRoute(host)
require.True(t, ok)
assert.Equal(t, RouteHTTP, route.Type, "HTTP takes precedence regardless of add order")
router.RemoveRoute(host, "svc-http")
router.RemoveRoute(host, "svc-tcp")
})
t.Run("same service id updates in place", func(t *testing.T) {
router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.1:443"})
router.AddRoute(host, Route{Type: RouteTCP, ServiceID: "svc-1", Target: "10.0.0.2:443"})
route, ok := router.lookupRoute(host)
require.True(t, ok)
assert.Equal(t, "10.0.0.2:443", route.Target, "route should be updated in place")
router.RemoveRoute(host, "svc-1")
_, ok = router.lookupRoute(host)
assert.False(t, ok)
})
t.Run("double remove is safe", func(t *testing.T) {
router.AddRoute(host, Route{Type: RouteHTTP, ServiceID: "svc-1"})
router.RemoveRoute(host, "svc-1")
router.RemoveRoute(host, "svc-1")
_, ok := router.lookupRoute(host)
assert.False(t, ok, "route should be gone after removal")
})
t.Run("remove does not affect other hosts", func(t *testing.T) {
router.AddRoute("a.example.com", Route{Type: RouteHTTP, ServiceID: "svc-a"})
router.AddRoute("b.example.com", Route{Type: RouteTCP, ServiceID: "svc-b", Target: "10.0.0.2:22"})
router.RemoveRoute("a.example.com", "svc-a")
_, ok := router.lookupRoute(SNIHost("a.example.com"))
assert.False(t, ok)
route, ok := router.lookupRoute(SNIHost("b.example.com"))
require.True(t, ok)
assert.Equal(t, RouteTCP, route.Type, "removing one host must not affect another")
router.RemoveRoute("b.example.com", "svc-b")
})
}
func TestRouter_SetRemoveFallback(t *testing.T) {
logger := log.StandardLogger()
router := NewPortRouter(logger, nil)
assert.True(t, router.IsEmpty(), "new port router should be empty")
router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb", Target: "10.0.0.1:5432"})
assert.False(t, router.IsEmpty(), "router with fallback should not be empty")
router.AddRoute("a.example.com", Route{Type: RouteTCP, ServiceID: "svc-a", Target: "10.0.0.2:443"})
assert.False(t, router.IsEmpty())
router.RemoveFallback("svc-fb")
assert.False(t, router.IsEmpty(), "router with SNI route should not be empty")
router.RemoveRoute("a.example.com", "svc-a")
assert.True(t, router.IsEmpty(), "router with no routes and no fallback should be empty")
}
func TestPortRouter_FallbackRelaysData(t *testing.T) {
// Backend echo server.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backendLn.Close()
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
}()
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
router.SetFallback(Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "test-service",
Target: backendLn.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Plain TCP (non-TLS) connection should be relayed via fallback.
// Use exactly 5 bytes. PeekClientHello reads 5 bytes as the TLS
// header, so a single 5-byte write lands as one chunk at the backend.
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn.Close()
testData := []byte("hello")
_, err = conn.Write(testData)
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n], "should receive echoed data through fallback relay")
}
func TestPortRouter_FallbackOnUnknownSNI(t *testing.T) {
// Backend TLS echo server.
backendCert := generateSelfSignedCert(t)
backendLn, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{backendCert},
})
require.NoError(t, err)
defer backendLn.Close()
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
}()
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
// Only a fallback, no SNI route for "unknown.example.com".
router.SetFallback(Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "test-service",
Target: backendLn.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// TLS with unknown SNI → fallback relay to TLS backend.
tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "tcp.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
defer tlsConn.Close()
testData := []byte("hello through fallback TLS")
_, err = tlsConn.Write(testData)
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n], "unknown SNI should relay through fallback")
}
func TestPortRouter_SNIWinsOverFallback(t *testing.T) {
// Two backend echo servers: one for SNI match, one for fallback.
sniBacked := startEchoTLS(t)
fbBacked := startEchoTLS(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
router.AddRoute("tcp.example.com", Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "sni-service",
Target: sniBacked.Addr().String(),
})
router.SetFallback(Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "fb-service",
Target: fbBacked.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// TLS with matching SNI should go to SNI backend, not fallback.
tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "tcp.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
defer tlsConn.Close()
testData := []byte("SNI route data")
_, err = tlsConn.Write(testData)
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n], "SNI match should use SNI route, not fallback")
}
func TestPortRouter_NoFallbackNoHTTP_Closes(t *testing.T) {
logger := log.StandardLogger()
router := NewPortRouter(logger, nil)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn.Close()
_, _ = conn.Write([]byte("hello"))
// Connection should be closed by the router (no fallback, no HTTP).
buf := make([]byte, 1)
_ = conn.SetReadDeadline(time.Now().Add(3 * time.Second))
_, err = conn.Read(buf)
assert.Error(t, err, "connection should be closed when no fallback and no HTTP channel")
}
func TestRouter_FallbackAndHTTPCoexist(t *testing.T) {
// Fallback backend echo server (plain TCP).
fbBackend, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer fbBackend.Close()
go func() {
conn, err := fbBackend.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
}()
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, dialResolve, addr)
// HTTP route for known SNI.
router.AddRoute("app.example.com", Route{Type: RouteHTTP})
// Fallback for non-TLS / unknown SNI.
router.SetFallback(Route{
Type: RouteTCP,
AccountID: "test-account",
ServiceID: "fb-service",
Target: fbBackend.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// 1. TLS with known HTTP SNI → should go to HTTP channel.
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
c := tls.Client(conn, &tls.Config{
ServerName: "app.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
_ = c.Handshake()
c.Close()
}()
select {
case httpConn := <-router.httpCh:
assert.NotNil(t, httpConn, "known HTTP SNI should go to HTTP channel")
httpConn.Close()
case <-time.After(5 * time.Second):
t.Fatal("HTTP-route connection was not delivered to HTTP handler")
}
// 2. Plain TCP (non-TLS) → should go to fallback, not HTTP.
// Use exactly 5 bytes to match PeekClientHello header size.
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn.Close()
testData := []byte("plain")
_, err = conn.Write(testData)
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n], "non-TLS should be relayed via fallback, not HTTP")
}
// startEchoTLS starts a TLS echo server and returns the listener.
func startEchoTLS(t *testing.T) net.Listener {
t.Helper()
cert := generateSelfSignedCert(t)
ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
})
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
if _, err := conn.Write(buf[:n]); err != nil {
return
}
}
}()
return ln
}
func generateSelfSignedCert(t *testing.T) tls.Certificate {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{"tcp.example.com"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
return tls.Certificate{
Certificate: [][]byte{certDER},
PrivateKey: key,
}
}
func TestRouter_DrainWaitsForRelays(t *testing.T) {
logger := log.StandardLogger()
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backendLn.Close()
// Accept connections: echo first message, then hold open until told to close.
closeBackend := make(chan struct{})
go func() {
for {
conn, err := backendLn.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 1024)
n, _ := c.Read(buf)
_, _ = c.Write(buf[:n])
<-closeBackend
}(conn)
}
}()
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
router := NewPortRouter(logger, dialResolve)
router.SetFallback(Route{
Type: RouteTCP,
Target: backendLn.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
serveDone := make(chan struct{})
go func() {
_ = router.Serve(ctx, ln)
close(serveDone)
}()
// Open a relay connection (non-TLS, hits fallback).
conn, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
_, _ = conn.Write([]byte("hello"))
// Wait for the echo to confirm the relay is fully established.
buf := make([]byte, 16)
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "hello", string(buf[:n]))
_ = conn.SetReadDeadline(time.Time{})
// Drain with a short timeout should fail because the relay is still active.
assert.False(t, router.Drain(50*time.Millisecond), "drain should timeout with active relay")
// Close backend connections so relays finish.
close(closeBackend)
_ = conn.Close()
// Drain should now complete quickly.
assert.True(t, router.Drain(2*time.Second), "drain should succeed after relays end")
cancel()
<-serveDone
}
func TestRouter_DrainEmptyReturnsImmediately(t *testing.T) {
logger := log.StandardLogger()
router := NewPortRouter(logger, nil)
start := time.Now()
ok := router.Drain(5 * time.Second)
elapsed := time.Since(start)
assert.True(t, ok)
assert.Less(t, elapsed, 100*time.Millisecond, "drain with no relays should return immediately")
}
// TestRemoveRoute_KillsActiveRelays verifies that removing a route
// immediately kills active relay connections for that service.
func TestRemoveRoute_KillsActiveRelays(t *testing.T) {
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backendLn.Close()
// Backend echoes first message, then holds connection open.
go func() {
for {
conn, err := backendLn.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 1024)
n, _ := c.Read(buf)
_, _ = c.Write(buf[:n])
// Hold the connection open.
for {
if _, err := c.Read(buf); err != nil {
return
}
}
}(conn)
}
}()
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
router.SetFallback(Route{
Type: RouteTCP,
ServiceID: "svc-1",
Target: backendLn.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Establish a relay connection.
conn, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("hello"))
require.NoError(t, err)
// Wait for echo to confirm relay is established.
buf := make([]byte, 16)
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "hello", string(buf[:n]))
_ = conn.SetReadDeadline(time.Time{})
// Remove the fallback: should kill the active relay.
router.RemoveFallback("svc-1")
// The client connection should see an error (server closed).
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(buf)
assert.Error(t, err, "connection should be killed after service removal")
}
// TestRemoveRoute_KillsSNIRelays verifies that removing an SNI route
// kills its active relays without affecting other services.
func TestRemoveRoute_KillsSNIRelays(t *testing.T) {
backend := startEchoTLS(t)
defer backend.Close()
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, dialResolve, addr)
router.AddRoute("tls.example.com", Route{
Type: RouteTCP,
ServiceID: "svc-tls",
Target: backend.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Establish a TLS relay.
tlsConn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true},
)
require.NoError(t, err)
defer tlsConn.Close()
_, err = tlsConn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "ping", string(buf[:n]))
// Remove the route: active relay should die.
router.RemoveRoute("tls.example.com", "svc-tls")
_ = tlsConn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = tlsConn.Read(buf)
assert.Error(t, err, "TLS relay should be killed after route removal")
}
// TestPortRouter_SNIAndTCPFallbackCoexist verifies that a single port can
// serve both SNI-routed TLS passthrough and plain TCP fallback simultaneously.
func TestPortRouter_SNIAndTCPFallbackCoexist(t *testing.T) {
sniBackend := startEchoTLS(t)
fbBackend := startEchoPlain(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
// SNI route for a specific domain.
router.AddRoute("tcp.example.com", Route{
Type: RouteTCP,
AccountID: "acct-1",
ServiceID: "svc-sni",
Target: sniBackend.Addr().String(),
})
// TCP fallback for everything else.
router.SetFallback(Route{
Type: RouteTCP,
AccountID: "acct-2",
ServiceID: "svc-fb",
Target: fbBackend.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// 1. TLS with matching SNI → goes to SNI backend.
tlsConn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "tcp.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
_, err = tlsConn.Write([]byte("sni-data"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "sni-data", string(buf[:n]), "SNI match → SNI backend")
tlsConn.Close()
// 2. Plain TCP (no TLS) → goes to fallback.
tcpConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
_, err = tcpConn.Write([]byte("plain"))
require.NoError(t, err)
n, err = tcpConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "plain", string(buf[:n]), "plain TCP → fallback backend")
tcpConn.Close()
// 3. TLS with unknown SNI → also goes to fallback.
unknownBackend := startEchoTLS(t)
router.SetFallback(Route{
Type: RouteTCP,
AccountID: "acct-2",
ServiceID: "svc-fb",
Target: unknownBackend.Addr().String(),
})
unknownTLS, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "unknown.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
_, err = unknownTLS.Write([]byte("unknown-sni"))
require.NoError(t, err)
n, err = unknownTLS.Read(buf)
require.NoError(t, err)
assert.Equal(t, "unknown-sni", string(buf[:n]), "unknown SNI → fallback backend")
unknownTLS.Close()
}
// TestPortRouter_UpdateRouteSwapsSNI verifies that updating a route
// (remove + add with different target) correctly routes to the new backend.
func TestPortRouter_UpdateRouteSwapsSNI(t *testing.T) {
backend1 := startEchoTLS(t)
backend2 := startEchoTLS(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Initial route → backend1.
router.AddRoute("db.example.com", Route{
Type: RouteTCP,
ServiceID: "svc-db",
Target: backend1.Addr().String(),
})
conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "db.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
_, err = conn1.Write([]byte("v1"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn1.Read(buf)
require.NoError(t, err)
assert.Equal(t, "v1", string(buf[:n]))
conn1.Close()
// Update: remove old route, add new → backend2.
router.RemoveRoute("db.example.com", "svc-db")
router.AddRoute("db.example.com", Route{
Type: RouteTCP,
ServiceID: "svc-db",
Target: backend2.Addr().String(),
})
conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "db.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
_, err = conn2.Write([]byte("v2"))
require.NoError(t, err)
n, err = conn2.Read(buf)
require.NoError(t, err)
assert.Equal(t, "v2", string(buf[:n]))
conn2.Close()
}
// TestPortRouter_RemoveSNIFallsThrough verifies that after removing an
// SNI route, connections for that domain fall through to the fallback.
func TestPortRouter_RemoveSNIFallsThrough(t *testing.T) {
sniBackend := startEchoTLS(t)
fbBackend := startEchoTLS(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
router.AddRoute("db.example.com", Route{
Type: RouteTCP,
ServiceID: "svc-db",
Target: sniBackend.Addr().String(),
})
router.SetFallback(Route{
Type: RouteTCP,
Target: fbBackend.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Before removal: SNI matches → sniBackend.
conn1, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "db.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
_, err = conn1.Write([]byte("before"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn1.Read(buf)
require.NoError(t, err)
assert.Equal(t, "before", string(buf[:n]))
conn1.Close()
// Remove SNI route. Should fall through to fallback.
router.RemoveRoute("db.example.com", "svc-db")
conn2, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
ServerName: "db.example.com",
InsecureSkipVerify: true, //nolint:gosec
})
require.NoError(t, err)
_, err = conn2.Write([]byte("after"))
require.NoError(t, err)
n, err = conn2.Read(buf)
require.NoError(t, err)
assert.Equal(t, "after", string(buf[:n]), "after removal, should reach fallback")
conn2.Close()
}
// TestPortRouter_RemoveFallbackCloses verifies that after removing the
// fallback, non-matching connections are closed.
func TestPortRouter_RemoveFallbackCloses(t *testing.T) {
fbBackend := startEchoPlain(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return func(_ context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, address)
}, nil
}
logger := log.StandardLogger()
router := NewPortRouter(logger, dialResolve)
router.SetFallback(Route{
Type: RouteTCP,
ServiceID: "svc-fb",
Target: fbBackend.Addr().String(),
})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// With fallback: plain TCP works.
conn1, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
_, err = conn1.Write([]byte("hello"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn1.Read(buf)
require.NoError(t, err)
assert.Equal(t, "hello", string(buf[:n]))
conn1.Close()
// Remove fallback.
router.RemoveFallback("svc-fb")
// Without fallback on a port router (no HTTP channel): connection should be closed.
conn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn2.Close()
_, _ = conn2.Write([]byte("bye"))
_ = conn2.SetReadDeadline(time.Now().Add(3 * time.Second))
_, err = conn2.Read(buf)
assert.Error(t, err, "without fallback, connection should be closed")
}
// TestPortRouter_HTTPToTLSTransition verifies that switching a service from
// HTTP-only to TLS-only via remove+add doesn't orphan the old HTTP route.
func TestPortRouter_HTTPToTLSTransition(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
tlsBackend := startEchoTLS(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
router := NewRouter(logger, dialResolve, addr)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Phase 1: HTTP-only. SNI connections go to HTTP channel.
router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"})
httpConn := router.HTTPListener()
connDone := make(chan struct{})
go func() {
defer close(connDone)
c, err := httpConn.Accept()
if err == nil {
c.Close()
}
}()
tlsConn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true},
)
if err == nil {
tlsConn.Close()
}
select {
case <-connDone:
case <-time.After(2 * time.Second):
t.Fatal("HTTP listener did not receive connection for HTTP-only route")
}
// Phase 2: Simulate update to TLS-only (removeMapping + addMapping).
router.RemoveRoute("app.example.com", "svc-1")
router.AddRoute("app.example.com", Route{
Type: RouteTCP,
AccountID: "acct-1",
ServiceID: "svc-1",
Target: tlsBackend.Addr().String(),
})
tlsConn2, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true},
)
require.NoError(t, err, "TLS connection should succeed after HTTP→TLS transition")
defer tlsConn2.Close()
_, err = tlsConn2.Write([]byte("hello-tls"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn2.Read(buf)
require.NoError(t, err)
assert.Equal(t, "hello-tls", string(buf[:n]), "data should relay to TLS backend")
}
// TestPortRouter_TLSToHTTPTransition verifies that switching a service from
// TLS-only to HTTP-only via remove+add doesn't orphan the old TLS route.
func TestPortRouter_TLSToHTTPTransition(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
tlsBackend := startEchoTLS(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
router := NewRouter(logger, dialResolve, addr)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Phase 1: TLS-only. Route relays to backend.
router.AddRoute("app.example.com", Route{
Type: RouteTCP,
AccountID: "acct-1",
ServiceID: "svc-1",
Target: tlsBackend.Addr().String(),
})
tlsConn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true},
)
require.NoError(t, err, "TLS relay should work before transition")
_, err = tlsConn.Write([]byte("tls-data"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "tls-data", string(buf[:n]))
tlsConn.Close()
// Phase 2: Simulate update to HTTP-only (removeMapping + addMapping).
router.RemoveRoute("app.example.com", "svc-1")
router.AddRoute("app.example.com", Route{Type: RouteHTTP, AccountID: "acct-1", ServiceID: "svc-1"})
// TLS connection should now go to the HTTP listener, NOT to the old TLS backend.
httpConn := router.HTTPListener()
connDone := make(chan struct{})
go func() {
defer close(connDone)
c, err := httpConn.Accept()
if err == nil {
c.Close()
}
}()
tlsConn2, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "app.example.com", InsecureSkipVerify: true},
)
if err == nil {
tlsConn2.Close()
}
select {
case <-connDone:
case <-time.After(2 * time.Second):
t.Fatal("HTTP listener should receive connection after TLS→HTTP transition")
}
}
// TestPortRouter_MultiDomainSamePort verifies that two TLS services sharing
// the same port router are independently routable and removable.
func TestPortRouter_MultiDomainSamePort(t *testing.T) {
logger := log.StandardLogger()
backend1 := startEchoTLSMulti(t)
backend2 := startEchoTLSMulti(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
router := NewPortRouter(logger, dialResolve)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
router.AddRoute("svc1.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: backend1.Addr().String()})
router.AddRoute("svc2.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-2", Target: backend2.Addr().String()})
assert.False(t, router.IsEmpty())
// Both domains route independently.
for _, tc := range []struct {
sni string
data string
}{
{"svc1.example.com", "hello-svc1"},
{"svc2.example.com", "hello-svc2"},
} {
conn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: tc.sni, InsecureSkipVerify: true},
)
require.NoError(t, err, "dial %s", tc.sni)
_, err = conn.Write([]byte(tc.data))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, tc.data, string(buf[:n]))
conn.Close()
}
// Remove svc1. Router should NOT be empty (svc2 still present).
router.RemoveRoute("svc1.example.com", "svc-1")
assert.False(t, router.IsEmpty(), "router should not be empty with one route remaining")
// svc2 still works.
conn2, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "svc2.example.com", InsecureSkipVerify: true},
)
require.NoError(t, err)
_, err = conn2.Write([]byte("still-alive"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := conn2.Read(buf)
require.NoError(t, err)
assert.Equal(t, "still-alive", string(buf[:n]))
conn2.Close()
// Remove svc2. Router is now empty.
router.RemoveRoute("svc2.example.com", "svc-2")
assert.True(t, router.IsEmpty(), "router should be empty after removing all routes")
}
// TestPortRouter_SNIAndFallbackLifecycle verifies the full lifecycle of SNI
// routes and TCP fallback coexisting on the same port router, including the
// ordering of add/remove operations.
func TestPortRouter_SNIAndFallbackLifecycle(t *testing.T) {
logger := log.StandardLogger()
sniBackend := startEchoTLS(t)
fallbackBackend := startEchoPlain(t)
dialResolve := func(_ types.AccountID) (types.DialContextFunc, error) {
return (&net.Dialer{}).DialContext, nil
}
router := NewPortRouter(logger, dialResolve)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Step 1: Add fallback first (port mapping), then SNI route (TLS service).
router.SetFallback(Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "pm-1", Target: fallbackBackend.Addr().String()})
router.AddRoute("tls.example.com", Route{Type: RouteTCP, AccountID: "acct-1", ServiceID: "svc-1", Target: sniBackend.Addr().String()})
assert.False(t, router.IsEmpty())
// SNI traffic goes to TLS backend.
tlsConn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 2 * time.Second},
"tcp", ln.Addr().String(),
&tls.Config{ServerName: "tls.example.com", InsecureSkipVerify: true},
)
require.NoError(t, err)
_, err = tlsConn.Write([]byte("sni-traffic"))
require.NoError(t, err)
buf := make([]byte, 1024)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "sni-traffic", string(buf[:n]))
tlsConn.Close()
// Plain TCP goes to fallback.
plainConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
_, err = plainConn.Write([]byte("plain"))
require.NoError(t, err)
n, err = plainConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, "plain", string(buf[:n]))
plainConn.Close()
// Step 2: Remove SNI route. Fallback still works, router not empty.
router.RemoveRoute("tls.example.com", "svc-1")
assert.False(t, router.IsEmpty(), "fallback still present")
plainConn2, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
// Must send >= 5 bytes so the SNI peek completes immediately
// without waiting for the 5-second peek timeout.
_, err = plainConn2.Write([]byte("after"))
require.NoError(t, err)
n, err = plainConn2.Read(buf)
require.NoError(t, err)
assert.Equal(t, "after", string(buf[:n]))
plainConn2.Close()
// Step 3: Remove fallback. Router is now empty.
router.RemoveFallback("pm-1")
assert.True(t, router.IsEmpty())
}
// TestPortRouter_IsEmptyTransitions verifies IsEmpty reflects correct state
// through all add/remove operations.
func TestPortRouter_IsEmptyTransitions(t *testing.T) {
logger := log.StandardLogger()
router := NewPortRouter(logger, nil)
assert.True(t, router.IsEmpty(), "new router")
router.AddRoute("a.com", Route{Type: RouteTCP, ServiceID: "svc-a"})
assert.False(t, router.IsEmpty(), "after adding route")
router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb1"})
assert.False(t, router.IsEmpty(), "route + fallback")
router.RemoveRoute("a.com", "svc-a")
assert.False(t, router.IsEmpty(), "fallback only")
router.RemoveFallback("svc-fb1")
assert.True(t, router.IsEmpty(), "all removed")
// Reverse order: fallback first, then route.
router.SetFallback(Route{Type: RouteTCP, ServiceID: "svc-fb2"})
assert.False(t, router.IsEmpty())
router.AddRoute("b.com", Route{Type: RouteTCP, ServiceID: "svc-b"})
assert.False(t, router.IsEmpty())
router.RemoveFallback("svc-fb2")
assert.False(t, router.IsEmpty(), "route still present")
router.RemoveRoute("b.com", "svc-b")
assert.True(t, router.IsEmpty(), "fully empty again")
}
// startEchoTLSMulti starts a TLS echo server that accepts multiple connections.
func startEchoTLSMulti(t *testing.T) net.Listener {
t.Helper()
cert := generateSelfSignedCert(t)
ln, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
})
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 1024)
n, _ := c.Read(buf)
_, _ = c.Write(buf[:n])
}(conn)
}
}()
return ln
}
// startEchoPlain starts a plain TCP echo server that reads until newline
// or connection close, then echoes the received data.
func startEchoPlain(t *testing.T) net.Listener {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() { ln.Close() })
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
// Set a read deadline so we don't block forever waiting for more data.
_ = c.SetReadDeadline(time.Now().Add(2 * time.Second))
buf := make([]byte, 1024)
n, _ := c.Read(buf)
_, _ = c.Write(buf[:n])
}(conn)
}
}()
return ln
}
// fakeAddr implements net.Addr with a custom string representation.
type fakeAddr string
func (f fakeAddr) Network() string { return "tcp" }
func (f fakeAddr) String() string { return string(f) }
// fakeConn is a minimal net.Conn with a controllable RemoteAddr.
type fakeConn struct {
net.Conn
remote net.Addr
}
func (f *fakeConn) RemoteAddr() net.Addr { return f.remote }
func TestCheckRestrictions_UnparseableAddress(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "unparsable address must be denied")
}
func TestCheckRestrictions_NilRemoteAddr(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
conn := &fakeConn{remote: nil}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(conn, route), "nil remote address must be denied")
}
func TestCheckRestrictions_AllowedAndDenied(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
allowed := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 1234}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(allowed, route), "10.1.2.3 in allowlist")
denied := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(denied, route), "192.168.1.1 not in allowlist")
}
func TestCheckRestrictions_NilFilter(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
route := Route{Filter: nil}
conn := &fakeConn{remote: fakeAddr("not-an-ip")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "nil filter should allow everything")
}
func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
router := NewPortRouter(log.StandardLogger(), nil)
filter := restrict.ParseFilter([]string{"10.0.0.0/8"}, nil, nil, nil)
route := Route{Filter: filter}
// net.IPv4() returns a 16-byte v4-in-v6 representation internally.
// The restriction check must Unmap it to match the v4 CIDR.
conn := &fakeConn{remote: &net.TCPAddr{IP: net.IPv4(10, 1, 2, 3), Port: 5678}}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn, route), "v4-in-v6 TCPAddr must match v4 CIDR")
// Explicitly v4-mapped-v6 address string.
conn6 := &fakeConn{remote: fakeAddr("[::ffff:10.1.2.3]:5678")}
assert.Equal(t, restrict.Allow, router.checkRestrictions(conn6, route), "::ffff:10.1.2.3 must match v4 CIDR")
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
}