mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
Add proxy protocol
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -42,6 +43,7 @@ type SNIProxy struct {
|
||||
localProxyPort int
|
||||
remoteConfigURL string
|
||||
publicKey string
|
||||
proxyProtocol bool // Enable PROXY protocol v1
|
||||
|
||||
// New fields for fast local SNI lookup
|
||||
localSNIs map[string]struct{}
|
||||
@@ -73,8 +75,63 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
||||
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Determine protocol family based on client IP and normalize target IP accordingly
|
||||
var protocol string
|
||||
var targetIP string
|
||||
|
||||
if clientTCP.IP.To4() != nil {
|
||||
// Client is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
|
||||
// or fall back to a sensible IPv4 address based on the target
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Client is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
clientTCP.IP.String(),
|
||||
targetIP,
|
||||
clientTCP.Port,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// NewSNIProxy creates a new SNI proxy instance
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) {
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create local overrides map
|
||||
@@ -94,6 +151,7 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
|
||||
localProxyPort: localProxyPort,
|
||||
remoteConfigURL: remoteConfigURL,
|
||||
publicKey: publicKey,
|
||||
proxyProtocol: proxyProtocol,
|
||||
localSNIs: make(map[string]struct{}),
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
@@ -265,6 +323,17 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
|
||||
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
||||
|
||||
// Send PROXY protocol header if enabled
|
||||
if p.proxyProtocol {
|
||||
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
||||
|
||||
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
||||
logger.Debug("Failed to send PROXY protocol header: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Track this tunnel by SNI
|
||||
p.activeTunnelsLock.Lock()
|
||||
tunnel, ok := p.activeTunnels[hostname]
|
||||
|
||||
78
proxy/proxy_test.go
Normal file
78
proxy/proxy_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildProxyProtocolHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientAddr string
|
||||
targetAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 client and target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "10.0.0.1:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv6 client and target",
|
||||
clientAddr: "[2001:db8::1]:12345",
|
||||
targetAddr: "[2001:db8::2]:443",
|
||||
expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv4 client with IPv6 loopback target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "[::1]:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv4 client with IPv6 target",
|
||||
clientAddr: "192.168.1.100:12345",
|
||||
targetAddr: "[2001:db8::2]:443",
|
||||
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
|
||||
},
|
||||
{
|
||||
name: "IPv6 client with IPv4 target",
|
||||
clientAddr: "[2001:db8::1]:12345",
|
||||
targetAddr: "10.0.0.1:443",
|
||||
expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve client address: %v", err)
|
||||
}
|
||||
|
||||
targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve target address: %v", err)
|
||||
}
|
||||
|
||||
result := buildProxyProtocolHeader(clientTCP, targetTCP)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
||||
// Test with non-TCP address type
|
||||
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
|
||||
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
|
||||
|
||||
result := buildProxyProtocolHeader(clientAddr, targetAddr)
|
||||
expected := "PROXY UNKNOWN\r\n"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user