From 7040a9436e644d808be310ecaa28e03553068e48 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 26 Aug 2025 22:26:01 -0700 Subject: [PATCH] Add proxy protocol --- README.md | 5 +++ main.go | 11 ++++++- proxy/proxy.go | 71 ++++++++++++++++++++++++++++++++++++++++- proxy/proxy_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 proxy/proxy_test.go diff --git a/README.md b/README.md index ed5db34..2dc9f64 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,9 @@ Gerbil includes an SNI (Server Name Indication) proxy that enables intelligent r - Otherwise, the proxy queries Pangolin's routing API to determine which node should handle the traffic - Supports caching of routing decisions to improve performance - Handles connection pooling and graceful shutdown +- Optional PROXY protocol v1 support to preserve original client IP addresses when forwarding to downstream proxies (HAProxy, Nginx, etc.) + +The PROXY protocol allows downstream proxies to know the real client IP address instead of seeing the SNI proxy's IP. When enabled with `--proxy-protocol`, the SNI proxy will prepend a PROXY protocol header to each connection containing the original client's IP and port information. In single node (self hosted) Pangolin deployments this can be bypassed by using port 443:443 to route to Traefik instead of the SNI proxy at 8443. @@ -56,6 +59,7 @@ Note: You must use either `config` or `remoteConfig` to configure WireGuard. - `local-proxy` (optional): Address for local proxy when routing local traffic. Default: `localhost` - `local-proxy-port` (optional): Port for local proxy when routing local traffic. Default: `443` - `local-overrides` (optional): Comma-separated list of domain names that should always be routed to the local proxy +- `proxy-protocol` (optional): Enable PROXY protocol v1 for preserving client IP addresses when forwarding to downstream proxies. Default: `false` ## Environment Variables @@ -74,6 +78,7 @@ All CLI arguments can also be provided via environment variables: - `LOCAL_PROXY`: Address for local proxy when routing local traffic - `LOCAL_PROXY_PORT`: Port for local proxy when routing local traffic - `LOCAL_OVERRIDES`: Comma-separated list of domain names that should always be routed to the local proxy +- `PROXY_PROTOCOL`: Enable PROXY protocol v1 for preserving client IP addresses (true/false) Example: diff --git a/main.go b/main.go index 37350ce..344433f 100644 --- a/main.go +++ b/main.go @@ -121,6 +121,7 @@ func main() { localProxyAddr string localProxyPort int localOverridesStr string + proxyProtocol bool ) interfaceName = os.Getenv("INTERFACE") @@ -137,6 +138,7 @@ func main() { localProxyAddr = os.Getenv("LOCAL_PROXY") localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT") localOverridesStr = os.Getenv("LOCAL_OVERRIDES") + proxyProtocolStr := os.Getenv("PROXY_PROTOCOL") if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -196,6 +198,13 @@ func main() { flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy") } + if proxyProtocolStr != "" { + proxyProtocol = strings.ToLower(proxyProtocolStr) == "true" + } + if proxyProtocolStr == "" { + flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP") + } + flag.Parse() logger.Init() @@ -314,7 +323,7 @@ func main() { logger.Info("Local overrides configured: %v", localOverrides) } - proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides) + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol) if err != nil { logger.Fatal("Failed to create proxy: %v", err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 7edd92e..7c7b078 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -11,6 +11,7 @@ import ( "log" "net" "net/http" + "strings" "sync" "time" @@ -42,6 +43,7 @@ type SNIProxy struct { localProxyPort int remoteConfigURL string publicKey string + proxyProtocol bool // Enable PROXY protocol v1 // New fields for fast local SNI lookup localSNIs map[string]struct{} @@ -73,8 +75,63 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } +// buildProxyProtocolHeader creates a PROXY protocol v1 header +func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string { + clientTCP, ok := clientAddr.(*net.TCPAddr) + if !ok { + // Fallback for unknown address types + return "PROXY UNKNOWN\r\n" + } + + targetTCP, ok := targetAddr.(*net.TCPAddr) + if !ok { + // Fallback for unknown address types + return "PROXY UNKNOWN\r\n" + } + + // Determine protocol family based on client IP and normalize target IP accordingly + var protocol string + var targetIP string + + if clientTCP.IP.To4() != nil { + // Client is IPv4, use TCP4 protocol + protocol = "TCP4" + if targetTCP.IP.To4() != nil { + // Target is also IPv4, use as-is + targetIP = targetTCP.IP.String() + } else { + // Target is IPv6, but we need IPv4 for consistent protocol family + // Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1 + if targetTCP.IP.IsLoopback() { + targetIP = "127.0.0.1" + } else { + // For non-loopback IPv6 targets, we could try to extract embedded IPv4 + // or fall back to a sensible IPv4 address based on the target + targetIP = "127.0.0.1" // Safe fallback + } + } + } else { + // Client is IPv6, use TCP6 protocol + protocol = "TCP6" + if targetTCP.IP.To4() != nil { + // Target is IPv4, convert to IPv6 representation + targetIP = "::ffff:" + targetTCP.IP.String() + } else { + // Target is also IPv6, use as-is + targetIP = targetTCP.IP.String() + } + } + + return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", + protocol, + clientTCP.IP.String(), + targetIP, + clientTCP.Port, + targetTCP.Port) +} + // NewSNIProxy creates a new SNI proxy instance -func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { +func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) // Create local overrides map @@ -94,6 +151,7 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo localProxyPort: localProxyPort, remoteConfigURL: remoteConfigURL, publicKey: publicKey, + proxyProtocol: proxyProtocol, localSNIs: make(map[string]struct{}), localOverrides: overridesMap, activeTunnels: make(map[string]*activeTunnel), @@ -265,6 +323,17 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) + // Send PROXY protocol header if enabled + if p.proxyProtocol { + proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr()) + logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader)) + + if _, err := targetConn.Write([]byte(proxyHeader)); err != nil { + logger.Debug("Failed to send PROXY protocol header: %v", err) + return + } + } + // Track this tunnel by SNI p.activeTunnelsLock.Lock() tunnel, ok := p.activeTunnels[hostname] diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go new file mode 100644 index 0000000..d585610 --- /dev/null +++ b/proxy/proxy_test.go @@ -0,0 +1,78 @@ +package proxy + +import ( + "net" + "testing" +) + +func TestBuildProxyProtocolHeader(t *testing.T) { + tests := []struct { + name string + clientAddr string + targetAddr string + expected string + }{ + { + name: "IPv4 client and target", + clientAddr: "192.168.1.100:12345", + targetAddr: "10.0.0.1:443", + expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n", + }, + { + name: "IPv6 client and target", + clientAddr: "[2001:db8::1]:12345", + targetAddr: "[2001:db8::2]:443", + expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n", + }, + { + name: "IPv4 client with IPv6 loopback target", + clientAddr: "192.168.1.100:12345", + targetAddr: "[::1]:443", + expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n", + }, + { + name: "IPv4 client with IPv6 target", + clientAddr: "192.168.1.100:12345", + targetAddr: "[2001:db8::2]:443", + expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n", + }, + { + name: "IPv6 client with IPv4 target", + clientAddr: "[2001:db8::1]:12345", + targetAddr: "10.0.0.1:443", + expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr) + if err != nil { + t.Fatalf("Failed to resolve client address: %v", err) + } + + targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr) + if err != nil { + t.Fatalf("Failed to resolve target address: %v", err) + } + + result := buildProxyProtocolHeader(clientTCP, targetTCP) + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) { + // Test with non-TCP address type + clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345} + targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443} + + result := buildProxyProtocolHeader(clientAddr, targetAddr) + expected := "PROXY UNKNOWN\r\n" + + if result != expected { + t.Errorf("Expected %q, got %q", expected, result) + } +}