Accept proxy protocol from other nodes

This commit is contained in:
Owen
2025-09-29 21:56:15 -07:00
parent 92992b8c14
commit 9038239bbe
3 changed files with 293 additions and 23 deletions

16
main.go
View File

@@ -121,6 +121,7 @@ func main() {
localProxyAddr string localProxyAddr string
localProxyPort int localProxyPort int
localOverridesStr string localOverridesStr string
trustedUpstreamsStr string
proxyProtocol bool proxyProtocol bool
) )
@@ -138,6 +139,7 @@ func main() {
localProxyAddr = os.Getenv("LOCAL_PROXY") localProxyAddr = os.Getenv("LOCAL_PROXY")
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT") localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
localOverridesStr = os.Getenv("LOCAL_OVERRIDES") localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL") proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
if interfaceName == "" { if interfaceName == "" {
@@ -197,6 +199,9 @@ func main() {
if localOverridesStr != "" { if localOverridesStr != "" {
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy") flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
} }
if trustedUpstreamsStr == "" {
flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol")
}
if proxyProtocolStr != "" { if proxyProtocolStr != "" {
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true" proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
@@ -323,7 +328,16 @@ func main() {
logger.Info("Local overrides configured: %v", localOverrides) logger.Info("Local overrides configured: %v", localOverrides)
} }
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol) var trustedUpstreams []string
if trustedUpstreamsStr != "" {
trustedUpstreams = strings.Split(trustedUpstreamsStr, ",")
for i, upstream := range trustedUpstreams {
trustedUpstreams[i] = strings.TrimSpace(upstream)
}
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
}
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
if err != nil { if err != nil {
logger.Fatal("Failed to create proxy: %v", err) logger.Fatal("Failed to create proxy: %v", err)
} }

View File

@@ -11,6 +11,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -31,6 +32,16 @@ type RouteAPIResponse struct {
Endpoints []string `json:"endpoints"` Endpoints []string `json:"endpoints"`
} }
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
type ProxyProtocolInfo struct {
Protocol string // TCP4 or TCP6
SrcIP string
DestIP string
SrcPort int
DestPort int
OriginalConn net.Conn // The original connection after PROXY protocol parsing
}
// SNIProxy represents the main proxy server // SNIProxy represents the main proxy server
type SNIProxy struct { type SNIProxy struct {
port int port int
@@ -55,6 +66,9 @@ type SNIProxy struct {
// Track active tunnels by SNI // Track active tunnels by SNI
activeTunnels map[string]*activeTunnel activeTunnels map[string]*activeTunnel
activeTunnelsLock sync.Mutex activeTunnelsLock sync.Mutex
// Trusted upstream proxies that can send PROXY protocol
trustedUpstreams map[string]struct{}
} }
type activeTunnel struct { type activeTunnel struct {
@@ -75,6 +89,159 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
// Check if the connection comes from a trusted upstream
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
}
// Resolve the remote IP to hostname to check if it's trusted
// For simplicity, we'll check the IP directly in trusted upstreams
// In production, you might want to do reverse DNS lookup
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
// Not from trusted upstream, return original connection
return nil, conn, nil
}
// Set read timeout for PROXY protocol parsing
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
}
// Read the first line (PROXY protocol header)
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
n, err := conn.Read(buffer)
if err != nil {
return nil, conn, fmt.Errorf("failed to read PROXY protocol header: %w", err)
}
// Find the end of the first line (CRLF)
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
if headerEnd == -1 {
return nil, conn, fmt.Errorf("PROXY protocol header not found")
}
headerLine := string(buffer[:headerEnd])
remainingData := buffer[headerEnd+2 : n]
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
parts := strings.Fields(headerLine)
if len(parts) != 6 || parts[0] != "PROXY" {
// Check for PROXY UNKNOWN
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
// PROXY UNKNOWN - use original connection info
return nil, conn, nil
}
return nil, conn, fmt.Errorf("invalid PROXY protocol header: %s", headerLine)
}
protocol := parts[1]
srcIP := parts[2]
destIP := parts[3]
srcPort, err := strconv.Atoi(parts[4])
if err != nil {
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
}
destPort, err := strconv.Atoi(parts[5])
if err != nil {
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
}
// Create a new reader that includes remaining data + original connection
var newReader io.Reader
if len(remainingData) > 0 {
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
} else {
newReader = conn
}
// Create a wrapper connection that reads from the combined reader
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
proxyInfo := &ProxyProtocolInfo{
Protocol: protocol,
SrcIP: srcIP,
DestIP: destIP,
SrcPort: srcPort,
DestPort: destPort,
OriginalConn: wrappedConn,
}
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
}
return proxyInfo, wrappedConn, nil
}
// proxyProtocolConn wraps a connection to read from a custom reader
type proxyProtocolConn struct {
net.Conn
reader io.Reader
}
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Use the original client information from the PROXY protocol
var targetIP string
var protocol string
// Parse source IP to determine protocol family
srcIP := net.ParseIP(proxyInfo.SrcIP)
if srcIP == nil {
return "PROXY UNKNOWN\r\n"
}
if srcIP.To4() != nil {
// Source is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Source is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
proxyInfo.SrcIP,
targetIP,
proxyInfo.SrcPort,
targetTCP.Port)
}
// buildProxyProtocolHeader creates a PROXY protocol v1 header // buildProxyProtocolHeader creates a PROXY protocol v1 header
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string { func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
clientTCP, ok := clientAddr.(*net.TCPAddr) clientTCP, ok := clientAddr.(*net.TCPAddr)
@@ -131,7 +298,7 @@ func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
} }
// NewSNIProxy creates a new SNI proxy instance // NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) { func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map // Create local overrides map
@@ -142,19 +309,36 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
} }
} }
// Create trusted upstreams map
trustedMap := make(map[string]struct{})
for _, upstream := range trustedUpstreams {
if upstream != "" {
// Add both the domain and potentially resolved IPs
trustedMap[upstream] = struct{}{}
// Try to resolve the domain to IPs and add them too
if ips, err := net.LookupIP(upstream); err == nil {
for _, ip := range ips {
trustedMap[ip.String()] = struct{}{}
}
}
}
}
proxy := &SNIProxy{ proxy := &SNIProxy{
port: port, port: port,
cache: cache.New(3*time.Second, 10*time.Minute), cache: cache.New(3*time.Second, 10*time.Minute),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
localProxyAddr: localProxyAddr, localProxyAddr: localProxyAddr,
localProxyPort: localProxyPort, localProxyPort: localProxyPort,
remoteConfigURL: remoteConfigURL, remoteConfigURL: remoteConfigURL,
publicKey: publicKey, publicKey: publicKey,
proxyProtocol: proxyProtocol, proxyProtocol: proxyProtocol,
localSNIs: make(map[string]struct{}), localSNIs: make(map[string]struct{}),
localOverrides: overridesMap, localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel), activeTunnels: make(map[string]*activeTunnel),
trustedUpstreams: trustedMap,
} }
return proxy, nil return proxy, nil
@@ -270,14 +454,31 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr()) logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
// Check for PROXY protocol from trusted upstream
var proxyInfo *ProxyProtocolInfo
var actualClientConn net.Conn = clientConn
if len(p.trustedUpstreams) > 0 {
var err error
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
if err != nil {
logger.Debug("Failed to parse PROXY protocol: %v", err)
return
}
if proxyInfo != nil {
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
}
}
// Set read timeout for SNI extraction // Set read timeout for SNI extraction
if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
logger.Debug("Failed to set read deadline: %v", err) logger.Debug("Failed to set read deadline: %v", err)
return return
} }
// Extract SNI hostname // Extract SNI hostname
hostname, clientReader, err := p.extractSNI(clientConn) hostname, clientReader, err := p.extractSNI(actualClientConn)
if err != nil { if err != nil {
logger.Debug("SNI extraction failed: %v", err) logger.Debug("SNI extraction failed: %v", err)
return return
@@ -291,13 +492,20 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
logger.Debug("SNI hostname detected: %s", hostname) logger.Debug("SNI hostname detected: %s", hostname)
// Remove read timeout for normal operation // Remove read timeout for normal operation
if err := clientConn.SetReadDeadline(time.Time{}); err != nil { if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err) logger.Debug("Failed to clear read deadline: %v", err)
return return
} }
// Get routing information // Get routing information - use original client address if available from PROXY protocol
route, err := p.getRoute(hostname, clientConn.RemoteAddr().String()) var clientAddrStr string
if proxyInfo != nil {
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
} else {
clientAddrStr = clientConn.RemoteAddr().String()
}
route, err := p.getRoute(hostname, clientAddrStr)
if err != nil { if err != nil {
logger.Debug("Failed to get route for %s: %v", hostname, err) logger.Debug("Failed to get route for %s: %v", hostname, err)
return return
@@ -325,7 +533,14 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
// Send PROXY protocol header if enabled // Send PROXY protocol header if enabled
if p.proxyProtocol { if p.proxyProtocol {
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr()) var proxyHeader string
if proxyInfo != nil {
// Use original client info from PROXY protocol
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
} else {
// Use direct client connection info
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
}
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader)) logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil { if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
@@ -341,7 +556,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
tunnel = &activeTunnel{} tunnel = &activeTunnel{}
p.activeTunnels[hostname] = tunnel p.activeTunnels[hostname] = tunnel
} }
tunnel.conns = append(tunnel.conns, clientConn) tunnel.conns = append(tunnel.conns, actualClientConn)
p.activeTunnelsLock.Unlock() p.activeTunnelsLock.Unlock()
defer func() { defer func() {
@@ -350,7 +565,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
if tunnel, ok := p.activeTunnels[hostname]; ok { if tunnel, ok := p.activeTunnels[hostname]; ok {
newConns := make([]net.Conn, 0, len(tunnel.conns)) newConns := make([]net.Conn, 0, len(tunnel.conns))
for _, c := range tunnel.conns { for _, c := range tunnel.conns {
if c != clientConn { if c != actualClientConn {
newConns = append(newConns, c) newConns = append(newConns, c)
} }
} }
@@ -364,7 +579,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
}() }()
// Start bidirectional data transfer // Start bidirectional data transfer
p.pipe(clientConn, targetConn, clientReader) p.pipe(actualClientConn, targetConn, clientReader)
} }
// getRoute retrieves routing information for a hostname // getRoute retrieves routing information for a hostname

View File

@@ -76,3 +76,44 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
t.Errorf("Expected %q, got %q", expected, result) t.Errorf("Expected %q, got %q", expected, result)
} }
} }
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
if err != nil {
t.Fatalf("Failed to create SNI proxy: %v", err)
}
// Test IPv4 case
proxyInfo := &ProxyProtocolInfo{
Protocol: "TCP4",
SrcIP: "10.0.0.1",
DestIP: "192.168.1.100",
SrcPort: 12345,
DestPort: 443,
}
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
if header != expected {
t.Errorf("Expected header '%s', got '%s'", expected, header)
}
// Test IPv6 case
proxyInfo = &ProxyProtocolInfo{
Protocol: "TCP6",
SrcIP: "2001:db8::1",
DestIP: "2001:db8::2",
SrcPort: 12345,
DestPort: 443,
}
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
if header != expected {
t.Errorf("Expected header '%s', got '%s'", expected, header)
}
}