mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
Accept proxy protocol from other nodes
This commit is contained in:
16
main.go
16
main.go
@@ -121,6 +121,7 @@ func main() {
|
|||||||
localProxyAddr string
|
localProxyAddr string
|
||||||
localProxyPort int
|
localProxyPort int
|
||||||
localOverridesStr string
|
localOverridesStr string
|
||||||
|
trustedUpstreamsStr string
|
||||||
proxyProtocol bool
|
proxyProtocol bool
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -138,6 +139,7 @@ func main() {
|
|||||||
localProxyAddr = os.Getenv("LOCAL_PROXY")
|
localProxyAddr = os.Getenv("LOCAL_PROXY")
|
||||||
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
|
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
|
||||||
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
||||||
|
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
|
||||||
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
||||||
|
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
@@ -197,6 +199,9 @@ func main() {
|
|||||||
if localOverridesStr != "" {
|
if localOverridesStr != "" {
|
||||||
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
|
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
|
||||||
}
|
}
|
||||||
|
if trustedUpstreamsStr == "" {
|
||||||
|
flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol")
|
||||||
|
}
|
||||||
|
|
||||||
if proxyProtocolStr != "" {
|
if proxyProtocolStr != "" {
|
||||||
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
|
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
|
||||||
@@ -323,7 +328,16 @@ func main() {
|
|||||||
logger.Info("Local overrides configured: %v", localOverrides)
|
logger.Info("Local overrides configured: %v", localOverrides)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol)
|
var trustedUpstreams []string
|
||||||
|
if trustedUpstreamsStr != "" {
|
||||||
|
trustedUpstreams = strings.Split(trustedUpstreamsStr, ",")
|
||||||
|
for i, upstream := range trustedUpstreams {
|
||||||
|
trustedUpstreams[i] = strings.TrimSpace(upstream)
|
||||||
|
}
|
||||||
|
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create proxy: %v", err)
|
logger.Fatal("Failed to create proxy: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
259
proxy/proxy.go
259
proxy/proxy.go
@@ -11,6 +11,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -31,6 +32,16 @@ type RouteAPIResponse struct {
|
|||||||
Endpoints []string `json:"endpoints"`
|
Endpoints []string `json:"endpoints"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
|
||||||
|
type ProxyProtocolInfo struct {
|
||||||
|
Protocol string // TCP4 or TCP6
|
||||||
|
SrcIP string
|
||||||
|
DestIP string
|
||||||
|
SrcPort int
|
||||||
|
DestPort int
|
||||||
|
OriginalConn net.Conn // The original connection after PROXY protocol parsing
|
||||||
|
}
|
||||||
|
|
||||||
// SNIProxy represents the main proxy server
|
// SNIProxy represents the main proxy server
|
||||||
type SNIProxy struct {
|
type SNIProxy struct {
|
||||||
port int
|
port int
|
||||||
@@ -55,6 +66,9 @@ type SNIProxy struct {
|
|||||||
// Track active tunnels by SNI
|
// Track active tunnels by SNI
|
||||||
activeTunnels map[string]*activeTunnel
|
activeTunnels map[string]*activeTunnel
|
||||||
activeTunnelsLock sync.Mutex
|
activeTunnelsLock sync.Mutex
|
||||||
|
|
||||||
|
// Trusted upstream proxies that can send PROXY protocol
|
||||||
|
trustedUpstreams map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type activeTunnel struct {
|
type activeTunnel struct {
|
||||||
@@ -75,6 +89,159 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
|||||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
|
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
|
||||||
|
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
|
||||||
|
// Check if the connection comes from a trusted upstream
|
||||||
|
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve the remote IP to hostname to check if it's trusted
|
||||||
|
// For simplicity, we'll check the IP directly in trusted upstreams
|
||||||
|
// In production, you might want to do reverse DNS lookup
|
||||||
|
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
|
||||||
|
// Not from trusted upstream, return original connection
|
||||||
|
return nil, conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set read timeout for PROXY protocol parsing
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the first line (PROXY protocol header)
|
||||||
|
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
|
||||||
|
n, err := conn.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn, fmt.Errorf("failed to read PROXY protocol header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the end of the first line (CRLF)
|
||||||
|
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
||||||
|
if headerEnd == -1 {
|
||||||
|
return nil, conn, fmt.Errorf("PROXY protocol header not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
headerLine := string(buffer[:headerEnd])
|
||||||
|
remainingData := buffer[headerEnd+2 : n]
|
||||||
|
|
||||||
|
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
|
||||||
|
parts := strings.Fields(headerLine)
|
||||||
|
if len(parts) != 6 || parts[0] != "PROXY" {
|
||||||
|
// Check for PROXY UNKNOWN
|
||||||
|
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
||||||
|
// PROXY UNKNOWN - use original connection info
|
||||||
|
return nil, conn, nil
|
||||||
|
}
|
||||||
|
return nil, conn, fmt.Errorf("invalid PROXY protocol header: %s", headerLine)
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := parts[1]
|
||||||
|
srcIP := parts[2]
|
||||||
|
destIP := parts[3]
|
||||||
|
srcPort, err := strconv.Atoi(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
||||||
|
}
|
||||||
|
destPort, err := strconv.Atoi(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new reader that includes remaining data + original connection
|
||||||
|
var newReader io.Reader
|
||||||
|
if len(remainingData) > 0 {
|
||||||
|
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
||||||
|
} else {
|
||||||
|
newReader = conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a wrapper connection that reads from the combined reader
|
||||||
|
wrappedConn := &proxyProtocolConn{
|
||||||
|
Conn: conn,
|
||||||
|
reader: newReader,
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyInfo := &ProxyProtocolInfo{
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DestIP: destIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DestPort: destPort,
|
||||||
|
OriginalConn: wrappedConn,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear read timeout
|
||||||
|
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||||
|
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyInfo, wrappedConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyProtocolConn wraps a connection to read from a custom reader
|
||||||
|
type proxyProtocolConn struct {
|
||||||
|
net.Conn
|
||||||
|
reader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
|
||||||
|
return c.reader.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
|
||||||
|
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
|
||||||
|
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
// Fallback for unknown address types
|
||||||
|
return "PROXY UNKNOWN\r\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the original client information from the PROXY protocol
|
||||||
|
var targetIP string
|
||||||
|
var protocol string
|
||||||
|
|
||||||
|
// Parse source IP to determine protocol family
|
||||||
|
srcIP := net.ParseIP(proxyInfo.SrcIP)
|
||||||
|
if srcIP == nil {
|
||||||
|
return "PROXY UNKNOWN\r\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
if srcIP.To4() != nil {
|
||||||
|
// Source is IPv4, use TCP4 protocol
|
||||||
|
protocol = "TCP4"
|
||||||
|
if targetTCP.IP.To4() != nil {
|
||||||
|
// Target is also IPv4, use as-is
|
||||||
|
targetIP = targetTCP.IP.String()
|
||||||
|
} else {
|
||||||
|
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||||
|
if targetTCP.IP.IsLoopback() {
|
||||||
|
targetIP = "127.0.0.1"
|
||||||
|
} else {
|
||||||
|
targetIP = "127.0.0.1" // Safe fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Source is IPv6, use TCP6 protocol
|
||||||
|
protocol = "TCP6"
|
||||||
|
if targetTCP.IP.To4() != nil {
|
||||||
|
// Target is IPv4, convert to IPv6 representation
|
||||||
|
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||||
|
} else {
|
||||||
|
// Target is also IPv6, use as-is
|
||||||
|
targetIP = targetTCP.IP.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||||
|
protocol,
|
||||||
|
proxyInfo.SrcIP,
|
||||||
|
targetIP,
|
||||||
|
proxyInfo.SrcPort,
|
||||||
|
targetTCP.Port)
|
||||||
|
}
|
||||||
|
|
||||||
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
||||||
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||||
@@ -131,7 +298,7 @@ func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSNIProxy creates a new SNI proxy instance
|
// NewSNIProxy creates a new SNI proxy instance
|
||||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) {
|
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
// Create local overrides map
|
// Create local overrides map
|
||||||
@@ -142,19 +309,36 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create trusted upstreams map
|
||||||
|
trustedMap := make(map[string]struct{})
|
||||||
|
for _, upstream := range trustedUpstreams {
|
||||||
|
if upstream != "" {
|
||||||
|
// Add both the domain and potentially resolved IPs
|
||||||
|
trustedMap[upstream] = struct{}{}
|
||||||
|
|
||||||
|
// Try to resolve the domain to IPs and add them too
|
||||||
|
if ips, err := net.LookupIP(upstream); err == nil {
|
||||||
|
for _, ip := range ips {
|
||||||
|
trustedMap[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxy := &SNIProxy{
|
proxy := &SNIProxy{
|
||||||
port: port,
|
port: port,
|
||||||
cache: cache.New(3*time.Second, 10*time.Minute),
|
cache: cache.New(3*time.Second, 10*time.Minute),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
localProxyAddr: localProxyAddr,
|
localProxyAddr: localProxyAddr,
|
||||||
localProxyPort: localProxyPort,
|
localProxyPort: localProxyPort,
|
||||||
remoteConfigURL: remoteConfigURL,
|
remoteConfigURL: remoteConfigURL,
|
||||||
publicKey: publicKey,
|
publicKey: publicKey,
|
||||||
proxyProtocol: proxyProtocol,
|
proxyProtocol: proxyProtocol,
|
||||||
localSNIs: make(map[string]struct{}),
|
localSNIs: make(map[string]struct{}),
|
||||||
localOverrides: overridesMap,
|
localOverrides: overridesMap,
|
||||||
activeTunnels: make(map[string]*activeTunnel),
|
activeTunnels: make(map[string]*activeTunnel),
|
||||||
|
trustedUpstreams: trustedMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
@@ -270,14 +454,31 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
|
|
||||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||||
|
|
||||||
|
// Check for PROXY protocol from trusted upstream
|
||||||
|
var proxyInfo *ProxyProtocolInfo
|
||||||
|
var actualClientConn net.Conn = clientConn
|
||||||
|
|
||||||
|
if len(p.trustedUpstreams) > 0 {
|
||||||
|
var err error
|
||||||
|
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if proxyInfo != nil {
|
||||||
|
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
||||||
|
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set read timeout for SNI extraction
|
// Set read timeout for SNI extraction
|
||||||
if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
logger.Debug("Failed to set read deadline: %v", err)
|
logger.Debug("Failed to set read deadline: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract SNI hostname
|
// Extract SNI hostname
|
||||||
hostname, clientReader, err := p.extractSNI(clientConn)
|
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("SNI extraction failed: %v", err)
|
logger.Debug("SNI extraction failed: %v", err)
|
||||||
return
|
return
|
||||||
@@ -291,13 +492,20 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
logger.Debug("SNI hostname detected: %s", hostname)
|
logger.Debug("SNI hostname detected: %s", hostname)
|
||||||
|
|
||||||
// Remove read timeout for normal operation
|
// Remove read timeout for normal operation
|
||||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||||
logger.Debug("Failed to clear read deadline: %v", err)
|
logger.Debug("Failed to clear read deadline: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get routing information
|
// Get routing information - use original client address if available from PROXY protocol
|
||||||
route, err := p.getRoute(hostname, clientConn.RemoteAddr().String())
|
var clientAddrStr string
|
||||||
|
if proxyInfo != nil {
|
||||||
|
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
|
||||||
|
} else {
|
||||||
|
clientAddrStr = clientConn.RemoteAddr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
route, err := p.getRoute(hostname, clientAddrStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
||||||
return
|
return
|
||||||
@@ -325,7 +533,14 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
|
|
||||||
// Send PROXY protocol header if enabled
|
// Send PROXY protocol header if enabled
|
||||||
if p.proxyProtocol {
|
if p.proxyProtocol {
|
||||||
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
var proxyHeader string
|
||||||
|
if proxyInfo != nil {
|
||||||
|
// Use original client info from PROXY protocol
|
||||||
|
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
|
||||||
|
} else {
|
||||||
|
// Use direct client connection info
|
||||||
|
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||||
|
}
|
||||||
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
||||||
|
|
||||||
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
||||||
@@ -341,7 +556,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
tunnel = &activeTunnel{}
|
tunnel = &activeTunnel{}
|
||||||
p.activeTunnels[hostname] = tunnel
|
p.activeTunnels[hostname] = tunnel
|
||||||
}
|
}
|
||||||
tunnel.conns = append(tunnel.conns, clientConn)
|
tunnel.conns = append(tunnel.conns, actualClientConn)
|
||||||
p.activeTunnelsLock.Unlock()
|
p.activeTunnelsLock.Unlock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -350,7 +565,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
if tunnel, ok := p.activeTunnels[hostname]; ok {
|
||||||
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
newConns := make([]net.Conn, 0, len(tunnel.conns))
|
||||||
for _, c := range tunnel.conns {
|
for _, c := range tunnel.conns {
|
||||||
if c != clientConn {
|
if c != actualClientConn {
|
||||||
newConns = append(newConns, c)
|
newConns = append(newConns, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -364,7 +579,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Start bidirectional data transfer
|
// Start bidirectional data transfer
|
||||||
p.pipe(clientConn, targetConn, clientReader)
|
p.pipe(actualClientConn, targetConn, clientReader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoute retrieves routing information for a hostname
|
// getRoute retrieves routing information for a hostname
|
||||||
|
|||||||
@@ -76,3 +76,44 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
|||||||
t.Errorf("Expected %q, got %q", expected, result)
|
t.Errorf("Expected %q, got %q", expected, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||||
|
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create SNI proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test IPv4 case
|
||||||
|
proxyInfo := &ProxyProtocolInfo{
|
||||||
|
Protocol: "TCP4",
|
||||||
|
SrcIP: "10.0.0.1",
|
||||||
|
DestIP: "192.168.1.100",
|
||||||
|
SrcPort: 12345,
|
||||||
|
DestPort: 443,
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
|
||||||
|
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||||
|
|
||||||
|
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
|
||||||
|
if header != expected {
|
||||||
|
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test IPv6 case
|
||||||
|
proxyInfo = &ProxyProtocolInfo{
|
||||||
|
Protocol: "TCP6",
|
||||||
|
SrcIP: "2001:db8::1",
|
||||||
|
DestIP: "2001:db8::2",
|
||||||
|
SrcPort: 12345,
|
||||||
|
DestPort: 443,
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
|
||||||
|
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||||
|
|
||||||
|
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
|
||||||
|
if header != expected {
|
||||||
|
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user