Add proxy protocol

This commit is contained in:
Owen
2025-08-26 22:26:01 -07:00
parent 04361242fe
commit 7040a9436e
4 changed files with 163 additions and 2 deletions

View File

@@ -11,6 +11,7 @@ import (
"log"
"net"
"net/http"
"strings"
"sync"
"time"
@@ -42,6 +43,7 @@ type SNIProxy struct {
localProxyPort int
remoteConfigURL string
publicKey string
proxyProtocol bool // Enable PROXY protocol v1
// New fields for fast local SNI lookup
localSNIs map[string]struct{}
@@ -73,8 +75,63 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// buildProxyProtocolHeader creates a PROXY protocol v1 header
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
clientTCP, ok := clientAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Determine protocol family based on client IP and normalize target IP accordingly
var protocol string
var targetIP string
if clientTCP.IP.To4() != nil {
// Client is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
// or fall back to a sensible IPv4 address based on the target
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Client is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
clientTCP.IP.String(),
targetIP,
clientTCP.Port,
targetTCP.Port)
}
// NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) {
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map
@@ -94,6 +151,7 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
localProxyPort: localProxyPort,
remoteConfigURL: remoteConfigURL,
publicKey: publicKey,
proxyProtocol: proxyProtocol,
localSNIs: make(map[string]struct{}),
localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel),
@@ -265,6 +323,17 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
// Send PROXY protocol header if enabled
if p.proxyProtocol {
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
logger.Debug("Failed to send PROXY protocol header: %v", err)
return
}
}
// Track this tunnel by SNI
p.activeTunnelsLock.Lock()
tunnel, ok := p.activeTunnels[hostname]

78
proxy/proxy_test.go Normal file
View File

@@ -0,0 +1,78 @@
package proxy
import (
"net"
"testing"
)
func TestBuildProxyProtocolHeader(t *testing.T) {
tests := []struct {
name string
clientAddr string
targetAddr string
expected string
}{
{
name: "IPv4 client and target",
clientAddr: "192.168.1.100:12345",
targetAddr: "10.0.0.1:443",
expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n",
},
{
name: "IPv6 client and target",
clientAddr: "[2001:db8::1]:12345",
targetAddr: "[2001:db8::2]:443",
expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n",
},
{
name: "IPv4 client with IPv6 loopback target",
clientAddr: "192.168.1.100:12345",
targetAddr: "[::1]:443",
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
},
{
name: "IPv4 client with IPv6 target",
clientAddr: "192.168.1.100:12345",
targetAddr: "[2001:db8::2]:443",
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
},
{
name: "IPv6 client with IPv4 target",
clientAddr: "[2001:db8::1]:12345",
targetAddr: "10.0.0.1:443",
expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr)
if err != nil {
t.Fatalf("Failed to resolve client address: %v", err)
}
targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr)
if err != nil {
t.Fatalf("Failed to resolve target address: %v", err)
}
result := buildProxyProtocolHeader(clientTCP, targetTCP)
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
// Test with non-TCP address type
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
result := buildProxyProtocolHeader(clientAddr, targetAddr)
expected := "PROXY UNKNOWN\r\n"
if result != expected {
t.Errorf("Expected %q, got %q", expected, result)
}
}