diff --git a/main.go b/main.go index 344433f..cb4a9ac 100644 --- a/main.go +++ b/main.go @@ -121,6 +121,7 @@ func main() { localProxyAddr string localProxyPort int localOverridesStr string + trustedUpstreamsStr string proxyProtocol bool ) @@ -138,6 +139,7 @@ func main() { localProxyAddr = os.Getenv("LOCAL_PROXY") localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT") localOverridesStr = os.Getenv("LOCAL_OVERRIDES") + trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS") proxyProtocolStr := os.Getenv("PROXY_PROTOCOL") if interfaceName == "" { @@ -197,6 +199,9 @@ func main() { if localOverridesStr != "" { flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy") } + if trustedUpstreamsStr == "" { + flag.StringVar(&trustedUpstreamsStr, "trusted-upstreams", "", "Comma-separated list of trusted upstream proxy domain names/IPs that can send PROXY protocol") + } if proxyProtocolStr != "" { proxyProtocol = strings.ToLower(proxyProtocolStr) == "true" @@ -323,7 +328,16 @@ func main() { logger.Info("Local overrides configured: %v", localOverrides) } - proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol) + var trustedUpstreams []string + if trustedUpstreamsStr != "" { + trustedUpstreams = strings.Split(trustedUpstreamsStr, ",") + for i, upstream := range trustedUpstreams { + trustedUpstreams[i] = strings.TrimSpace(upstream) + } + logger.Info("Trusted upstreams configured: %v", trustedUpstreams) + } + + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams) if err != nil { logger.Fatal("Failed to create proxy: %v", err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 7c7b078..e2e0c73 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -11,6 +11,7 @@ import ( "log" "net" "net/http" + "strconv" "strings" "sync" "time" @@ -31,6 +32,16 @@ type RouteAPIResponse struct { Endpoints []string `json:"endpoints"` } +// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header +type ProxyProtocolInfo struct { + Protocol string // TCP4 or TCP6 + SrcIP string + DestIP string + SrcPort int + DestPort int + OriginalConn net.Conn // The original connection after PROXY protocol parsing +} + // SNIProxy represents the main proxy server type SNIProxy struct { port int @@ -55,6 +66,9 @@ type SNIProxy struct { // Track active tunnels by SNI activeTunnels map[string]*activeTunnel activeTunnelsLock sync.Mutex + + // Trusted upstream proxies that can send PROXY protocol + trustedUpstreams map[string]struct{} } type activeTunnel struct { @@ -75,6 +89,159 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } +// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection +func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) { + // Check if the connection comes from a trusted upstream + remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + return nil, conn, fmt.Errorf("failed to parse remote address: %w", err) + } + + // Resolve the remote IP to hostname to check if it's trusted + // For simplicity, we'll check the IP directly in trusted upstreams + // In production, you might want to do reverse DNS lookup + if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted { + // Not from trusted upstream, return original connection + return nil, conn, nil + } + + // Set read timeout for PROXY protocol parsing + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + return nil, conn, fmt.Errorf("failed to set read deadline: %w", err) + } + + // Read the first line (PROXY protocol header) + buffer := make([]byte, 512) // PROXY protocol header should be much smaller + n, err := conn.Read(buffer) + if err != nil { + return nil, conn, fmt.Errorf("failed to read PROXY protocol header: %w", err) + } + + // Find the end of the first line (CRLF) + headerEnd := bytes.Index(buffer[:n], []byte("\r\n")) + if headerEnd == -1 { + return nil, conn, fmt.Errorf("PROXY protocol header not found") + } + + headerLine := string(buffer[:headerEnd]) + remainingData := buffer[headerEnd+2 : n] + + // Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort" + parts := strings.Fields(headerLine) + if len(parts) != 6 || parts[0] != "PROXY" { + // Check for PROXY UNKNOWN + if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" { + // PROXY UNKNOWN - use original connection info + return nil, conn, nil + } + return nil, conn, fmt.Errorf("invalid PROXY protocol header: %s", headerLine) + } + + protocol := parts[1] + srcIP := parts[2] + destIP := parts[3] + srcPort, err := strconv.Atoi(parts[4]) + if err != nil { + return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4]) + } + destPort, err := strconv.Atoi(parts[5]) + if err != nil { + return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5]) + } + + // Create a new reader that includes remaining data + original connection + var newReader io.Reader + if len(remainingData) > 0 { + newReader = io.MultiReader(bytes.NewReader(remainingData), conn) + } else { + newReader = conn + } + + // Create a wrapper connection that reads from the combined reader + wrappedConn := &proxyProtocolConn{ + Conn: conn, + reader: newReader, + } + + proxyInfo := &ProxyProtocolInfo{ + Protocol: protocol, + SrcIP: srcIP, + DestIP: destIP, + SrcPort: srcPort, + DestPort: destPort, + OriginalConn: wrappedConn, + } + + // Clear read timeout + if err := conn.SetReadDeadline(time.Time{}); err != nil { + return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err) + } + + return proxyInfo, wrappedConn, nil +} + +// proxyProtocolConn wraps a connection to read from a custom reader +type proxyProtocolConn struct { + net.Conn + reader io.Reader +} + +func (c *proxyProtocolConn) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo +func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string { + targetTCP, ok := targetAddr.(*net.TCPAddr) + if !ok { + // Fallback for unknown address types + return "PROXY UNKNOWN\r\n" + } + + // Use the original client information from the PROXY protocol + var targetIP string + var protocol string + + // Parse source IP to determine protocol family + srcIP := net.ParseIP(proxyInfo.SrcIP) + if srcIP == nil { + return "PROXY UNKNOWN\r\n" + } + + if srcIP.To4() != nil { + // Source is IPv4, use TCP4 protocol + protocol = "TCP4" + if targetTCP.IP.To4() != nil { + // Target is also IPv4, use as-is + targetIP = targetTCP.IP.String() + } else { + // Target is IPv6, but we need IPv4 for consistent protocol family + if targetTCP.IP.IsLoopback() { + targetIP = "127.0.0.1" + } else { + targetIP = "127.0.0.1" // Safe fallback + } + } + } else { + // Source is IPv6, use TCP6 protocol + protocol = "TCP6" + if targetTCP.IP.To4() != nil { + // Target is IPv4, convert to IPv6 representation + targetIP = "::ffff:" + targetTCP.IP.String() + } else { + // Target is also IPv6, use as-is + targetIP = targetTCP.IP.String() + } + } + + return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", + protocol, + proxyInfo.SrcIP, + targetIP, + proxyInfo.SrcPort, + targetTCP.Port) +} + // buildProxyProtocolHeader creates a PROXY protocol v1 header func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string { clientTCP, ok := clientAddr.(*net.TCPAddr) @@ -131,7 +298,7 @@ func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string { } // NewSNIProxy creates a new SNI proxy instance -func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) { +func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) // Create local overrides map @@ -142,19 +309,36 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo } } + // Create trusted upstreams map + trustedMap := make(map[string]struct{}) + for _, upstream := range trustedUpstreams { + if upstream != "" { + // Add both the domain and potentially resolved IPs + trustedMap[upstream] = struct{}{} + + // Try to resolve the domain to IPs and add them too + if ips, err := net.LookupIP(upstream); err == nil { + for _, ip := range ips { + trustedMap[ip.String()] = struct{}{} + } + } + } + } + proxy := &SNIProxy{ - port: port, - cache: cache.New(3*time.Second, 10*time.Minute), - ctx: ctx, - cancel: cancel, - localProxyAddr: localProxyAddr, - localProxyPort: localProxyPort, - remoteConfigURL: remoteConfigURL, - publicKey: publicKey, - proxyProtocol: proxyProtocol, - localSNIs: make(map[string]struct{}), - localOverrides: overridesMap, - activeTunnels: make(map[string]*activeTunnel), + port: port, + cache: cache.New(3*time.Second, 10*time.Minute), + ctx: ctx, + cancel: cancel, + localProxyAddr: localProxyAddr, + localProxyPort: localProxyPort, + remoteConfigURL: remoteConfigURL, + publicKey: publicKey, + proxyProtocol: proxyProtocol, + localSNIs: make(map[string]struct{}), + localOverrides: overridesMap, + activeTunnels: make(map[string]*activeTunnel), + trustedUpstreams: trustedMap, } return proxy, nil @@ -270,14 +454,31 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { logger.Debug("Accepted connection from %s", clientConn.RemoteAddr()) + // Check for PROXY protocol from trusted upstream + var proxyInfo *ProxyProtocolInfo + var actualClientConn net.Conn = clientConn + + if len(p.trustedUpstreams) > 0 { + var err error + proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn) + if err != nil { + logger.Debug("Failed to parse PROXY protocol: %v", err) + return + } + if proxyInfo != nil { + logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d", + proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort) + } + } + // Set read timeout for SNI extraction - if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { logger.Debug("Failed to set read deadline: %v", err) return } // Extract SNI hostname - hostname, clientReader, err := p.extractSNI(clientConn) + hostname, clientReader, err := p.extractSNI(actualClientConn) if err != nil { logger.Debug("SNI extraction failed: %v", err) return @@ -291,13 +492,20 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { logger.Debug("SNI hostname detected: %s", hostname) // Remove read timeout for normal operation - if err := clientConn.SetReadDeadline(time.Time{}); err != nil { + if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil { logger.Debug("Failed to clear read deadline: %v", err) return } - // Get routing information - route, err := p.getRoute(hostname, clientConn.RemoteAddr().String()) + // Get routing information - use original client address if available from PROXY protocol + var clientAddrStr string + if proxyInfo != nil { + clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort) + } else { + clientAddrStr = clientConn.RemoteAddr().String() + } + + route, err := p.getRoute(hostname, clientAddrStr) if err != nil { logger.Debug("Failed to get route for %s: %v", hostname, err) return @@ -325,7 +533,14 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { // Send PROXY protocol header if enabled if p.proxyProtocol { - proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr()) + var proxyHeader string + if proxyInfo != nil { + // Use original client info from PROXY protocol + proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr()) + } else { + // Use direct client connection info + proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr()) + } logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader)) if _, err := targetConn.Write([]byte(proxyHeader)); err != nil { @@ -341,7 +556,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { tunnel = &activeTunnel{} p.activeTunnels[hostname] = tunnel } - tunnel.conns = append(tunnel.conns, clientConn) + tunnel.conns = append(tunnel.conns, actualClientConn) p.activeTunnelsLock.Unlock() defer func() { @@ -350,7 +565,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { if tunnel, ok := p.activeTunnels[hostname]; ok { newConns := make([]net.Conn, 0, len(tunnel.conns)) for _, c := range tunnel.conns { - if c != clientConn { + if c != actualClientConn { newConns = append(newConns, c) } } @@ -364,7 +579,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { }() // Start bidirectional data transfer - p.pipe(clientConn, targetConn, clientReader) + p.pipe(actualClientConn, targetConn, clientReader) } // getRoute retrieves routing information for a hostname diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index d585610..747c81d 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -76,3 +76,44 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) { t.Errorf("Expected %q, got %q", expected, result) } } + +func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) { + proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil) + if err != nil { + t.Fatalf("Failed to create SNI proxy: %v", err) + } + + // Test IPv4 case + proxyInfo := &ProxyProtocolInfo{ + Protocol: "TCP4", + SrcIP: "10.0.0.1", + DestIP: "192.168.1.100", + SrcPort: 12345, + DestPort: 443, + } + + targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") + header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr) + + expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n" + if header != expected { + t.Errorf("Expected header '%s', got '%s'", expected, header) + } + + // Test IPv6 case + proxyInfo = &ProxyProtocolInfo{ + Protocol: "TCP6", + SrcIP: "2001:db8::1", + DestIP: "2001:db8::2", + SrcPort: 12345, + DestPort: 443, + } + + targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080") + header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr) + + expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n" + if header != expected { + t.Errorf("Expected header '%s', got '%s'", expected, header) + } +}