Add pp to relay

This commit is contained in:
Owen
2026-03-27 17:21:44 -07:00
parent 40da38708c
commit 9ce372e644
5 changed files with 573 additions and 300 deletions

View File

@@ -11,12 +11,12 @@ import (
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxyproto"
"github.com/patrickmn/go-cache"
)
@@ -32,16 +32,6 @@ type RouteAPIResponse struct {
Endpoints []string `json:"endpoints"`
}
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
type ProxyProtocolInfo struct {
Protocol string // TCP4 or TCP6
SrcIP string
DestIP string
SrcPort int
DestPort int
OriginalConn net.Conn // The original connection after PROXY protocol parsing
}
// SNIProxy represents the main proxy server
type SNIProxy struct {
port int
@@ -89,249 +79,6 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
// Check if the connection comes from a trusted upstream
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
}
// Resolve the remote IP to hostname to check if it's trusted
// For simplicity, we'll check the IP directly in trusted upstreams
// In production, you might want to do reverse DNS lookup
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
// Not from trusted upstream, return original connection
return nil, conn, nil
}
// Set read timeout for PROXY protocol parsing
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
}
// Read the first line (PROXY protocol header)
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
n, err := conn.Read(buffer)
if err != nil {
// If we can't read from trusted upstream, treat as regular connection
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
// Clear read timeout before returning
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
logger.Debug("Failed to clear read deadline: %v", clearErr)
}
return nil, conn, nil
}
// Find the end of the first line (CRLF)
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
if headerEnd == -1 {
// No PROXY protocol header found, treat as regular TLS connection
// Return the connection with the buffered data prepended
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
}
// Create a reader that includes the buffered data + original connection
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
return nil, wrappedConn, nil
}
headerLine := string(buffer[:headerEnd])
remainingData := buffer[headerEnd+2 : n]
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
parts := strings.Fields(headerLine)
if len(parts) != 6 || parts[0] != "PROXY" {
// Check for PROXY UNKNOWN
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
// PROXY UNKNOWN - use original connection info
return nil, conn, nil
}
// Invalid PROXY protocol, but might be regular TLS - treat as such
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
}
// Return the connection with all buffered data prepended
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
return nil, wrappedConn, nil
}
protocol := parts[1]
srcIP := parts[2]
destIP := parts[3]
srcPort, err := strconv.Atoi(parts[4])
if err != nil {
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
}
destPort, err := strconv.Atoi(parts[5])
if err != nil {
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
}
// Create a new reader that includes remaining data + original connection
var newReader io.Reader
if len(remainingData) > 0 {
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
} else {
newReader = conn
}
// Create a wrapper connection that reads from the combined reader
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
proxyInfo := &ProxyProtocolInfo{
Protocol: protocol,
SrcIP: srcIP,
DestIP: destIP,
SrcPort: srcPort,
DestPort: destPort,
OriginalConn: wrappedConn,
}
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
}
return proxyInfo, wrappedConn, nil
}
// proxyProtocolConn wraps a connection to read from a custom reader
type proxyProtocolConn struct {
net.Conn
reader io.Reader
}
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Use the original client information from the PROXY protocol
var targetIP string
var protocol string
// Parse source IP to determine protocol family
srcIP := net.ParseIP(proxyInfo.SrcIP)
if srcIP == nil {
return "PROXY UNKNOWN\r\n"
}
if srcIP.To4() != nil {
// Source is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Source is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
proxyInfo.SrcIP,
targetIP,
proxyInfo.SrcPort,
targetTCP.Port)
}
// buildProxyProtocolHeader creates a PROXY protocol v1 header
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
clientTCP, ok := clientAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Determine protocol family based on client IP and normalize target IP accordingly
var protocol string
var targetIP string
if clientTCP.IP.To4() != nil {
// Client is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
// or fall back to a sensible IPv4 address based on the target
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Client is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
clientTCP.IP.String(),
targetIP,
clientTCP.Port,
targetTCP.Port)
}
// NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
@@ -490,12 +237,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
// Check for PROXY protocol from trusted upstream
var proxyInfo *ProxyProtocolInfo
var proxyInfo *proxyproto.Info
var actualClientConn net.Conn = clientConn
if len(p.trustedUpstreams) > 0 {
var err error
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
proxyInfo, actualClientConn, err = proxyproto.ParseV1Header(clientConn, p.trustedUpstreams)
if err != nil {
logger.Debug("Failed to parse PROXY protocol: %v", err)
return
@@ -575,10 +322,10 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
var proxyHeader string
if proxyInfo != nil {
// Use original client info from PROXY protocol
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
proxyHeader = proxyproto.BuildV1HeaderFromInfo(proxyInfo, targetConn.LocalAddr())
} else {
// Use direct client connection info
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
proxyHeader = proxyproto.BuildV1Header(clientConn.RemoteAddr(), targetConn.LocalAddr())
}
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))

View File

@@ -3,6 +3,8 @@ package proxy
import (
"net"
"testing"
"github.com/fosrl/gerbil/proxyproto"
)
func TestBuildProxyProtocolHeader(t *testing.T) {
@@ -56,7 +58,7 @@ func TestBuildProxyProtocolHeader(t *testing.T) {
t.Fatalf("Failed to resolve target address: %v", err)
}
result := buildProxyProtocolHeader(clientTCP, targetTCP)
result := proxyproto.BuildV1Header(clientTCP, targetTCP)
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
@@ -69,7 +71,7 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
result := buildProxyProtocolHeader(clientAddr, targetAddr)
result := proxyproto.BuildV1Header(clientAddr, targetAddr)
expected := "PROXY UNKNOWN\r\n"
if result != expected {
@@ -78,13 +80,8 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
}
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
if err != nil {
t.Fatalf("Failed to create SNI proxy: %v", err)
}
// Test IPv4 case
proxyInfo := &ProxyProtocolInfo{
info := &proxyproto.Info{
Protocol: "TCP4",
SrcIP: "10.0.0.1",
DestIP: "192.168.1.100",
@@ -93,7 +90,7 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
}
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
header := proxyproto.BuildV1HeaderFromInfo(info, targetAddr)
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
if header != expected {
@@ -101,7 +98,7 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
}
// Test IPv6 case
proxyInfo = &ProxyProtocolInfo{
info = &proxyproto.Info{
Protocol: "TCP6",
SrcIP: "2001:db8::1",
DestIP: "2001:db8::2",
@@ -110,10 +107,99 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
}
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
header = proxyproto.BuildV1HeaderFromInfo(info, targetAddr)
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
if header != expected {
t.Errorf("Expected header '%s', got '%s'", expected, header)
}
}
func TestParseV2UDPHeader(t *testing.T) {
// Build a minimal PROXY v2 header for IPv4 UDP
// Magic (12) + ver/cmd (1) + fam/proto (1) + len (2) + src IP (4) + dst IP (4) + src port (2) + dst port (2) = 28 bytes
header := []byte{
// Magic signature
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
// Version 2 (0x2x), PROXY command (0x01)
0x21,
// AF_INET (0x1x), DGRAM/UDP (0x02)
0x12,
// Address block length: 12 bytes (4+4+2+2)
0x00, 0x0C,
// Source IP: 192.168.1.100
192, 168, 1, 100,
// Destination IP: 10.0.0.1
10, 0, 0, 1,
// Source port: 4500
0x11, 0x94,
// Destination port: 21820
0x55, 0x3C,
}
// Append a fake application payload
payload := []byte{0x01, 0x02, 0x03}
data := append(header, payload...)
info, remaining, ok := proxyproto.ParseV2UDPHeader(data)
if !ok {
t.Fatal("Expected ParseV2UDPHeader to return ok=true")
}
if info == nil {
t.Fatal("Expected non-nil Info")
}
if info.Protocol != "UDP4" {
t.Errorf("Expected protocol UDP4, got %s", info.Protocol)
}
if info.SrcIP != "192.168.1.100" {
t.Errorf("Expected SrcIP 192.168.1.100, got %s", info.SrcIP)
}
if info.DestIP != "10.0.0.1" {
t.Errorf("Expected DestIP 10.0.0.1, got %s", info.DestIP)
}
if info.SrcPort != 4500 {
t.Errorf("Expected SrcPort 4500, got %d", info.SrcPort)
}
if info.DestPort != 21820 {
t.Errorf("Expected DestPort 21820, got %d", info.DestPort)
}
if len(remaining) != len(payload) {
t.Errorf("Expected %d remaining bytes, got %d", len(payload), len(remaining))
}
}
func TestParseV2UDPHeaderNoHeader(t *testing.T) {
// Data that does NOT start with v2 magic should be returned as-is
data := []byte{0x01, 0x02, 0x03}
info, remaining, ok := proxyproto.ParseV2UDPHeader(data)
if ok {
t.Error("Expected ok=false for non-v2 data")
}
if info != nil {
t.Error("Expected nil Info for non-v2 data")
}
if len(remaining) != len(data) {
t.Errorf("Expected remaining to equal original data length %d, got %d", len(data), len(remaining))
}
}
func TestIsV2Header(t *testing.T) {
valid := []byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
// extra bytes beyond the magic
0x21, 0x12,
}
if !proxyproto.IsV2Header(valid) {
t.Error("Expected IsV2Header=true for valid magic")
}
invalid := []byte{0x01, 0x02, 0x03}
if proxyproto.IsV2Header(invalid) {
t.Error("Expected IsV2Header=false for non-magic data")
}
tooShort := []byte{0x0D, 0x0A}
if proxyproto.IsV2Header(tooShort) {
t.Error("Expected IsV2Header=false for too-short data")
}
}