diff --git a/main.go b/main.go index dde0026..7d02f86 100644 --- a/main.go +++ b/main.go @@ -355,21 +355,8 @@ func main() { return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth") }) - // Start the UDP proxy server - relayPort := wgconfig.RelayPort - if relayPort == 0 { - relayPort = 21820 // in case there is no relay port set, use 21820 - } - proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt) - err = proxyRelay.Start() - if err != nil { - logger.Fatal("Failed to start UDP proxy server: %v", err) - } - defer proxyRelay.Stop() - - // TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING - // SO YOU DON'T NEED TO SET THIS SEPARATELY - // Parse local overrides + // Parse local overrides and trusted upstreams early so that both the relay + // and the SNI proxy share the same configuration values. var localOverrides []string if localOverridesStr != "" { localOverrides = strings.Split(localOverridesStr, ",") @@ -388,6 +375,21 @@ func main() { logger.Info("Trusted upstreams configured: %v", trustedUpstreams) } + // Start the UDP proxy server. + // proxyProtocol and trustedUpstreams are forwarded so the relay can strip + // PROXY protocol v2 headers from load-balancer traffic and recover the + // original client IP for hole-punch registration. + relayPort := wgconfig.RelayPort + if relayPort == 0 { + relayPort = 21820 // in case there is no relay port set, use 21820 + } + proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt, proxyProtocol, trustedUpstreams) + err = proxyRelay.Start() + if err != nil { + logger.Fatal("Failed to start UDP proxy server: %v", err) + } + defer proxyRelay.Stop() + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams) if err != nil { logger.Fatal("Failed to create proxy: %v", err) diff --git a/proxy/proxy.go b/proxy/proxy.go index f29878e..5c67131 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -11,12 +11,12 @@ import ( "log" "net" "net/http" - "strconv" "strings" "sync" "time" "github.com/fosrl/gerbil/logger" + "github.com/fosrl/gerbil/proxyproto" "github.com/patrickmn/go-cache" ) @@ -32,16 +32,6 @@ type RouteAPIResponse struct { Endpoints []string `json:"endpoints"` } -// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header -type ProxyProtocolInfo struct { - Protocol string // TCP4 or TCP6 - SrcIP string - DestIP string - SrcPort int - DestPort int - OriginalConn net.Conn // The original connection after PROXY protocol parsing -} - // SNIProxy represents the main proxy server type SNIProxy struct { port int @@ -89,249 +79,6 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } -// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection -func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) { - // Check if the connection comes from a trusted upstream - remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) - if err != nil { - return nil, conn, fmt.Errorf("failed to parse remote address: %w", err) - } - - // Resolve the remote IP to hostname to check if it's trusted - // For simplicity, we'll check the IP directly in trusted upstreams - // In production, you might want to do reverse DNS lookup - if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted { - // Not from trusted upstream, return original connection - return nil, conn, nil - } - - // Set read timeout for PROXY protocol parsing - if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - return nil, conn, fmt.Errorf("failed to set read deadline: %w", err) - } - - // Read the first line (PROXY protocol header) - buffer := make([]byte, 512) // PROXY protocol header should be much smaller - n, err := conn.Read(buffer) - if err != nil { - // If we can't read from trusted upstream, treat as regular connection - logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err) - // Clear read timeout before returning - if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil { - logger.Debug("Failed to clear read deadline: %v", clearErr) - } - return nil, conn, nil - } - - // Find the end of the first line (CRLF) - headerEnd := bytes.Index(buffer[:n], []byte("\r\n")) - if headerEnd == -1 { - // No PROXY protocol header found, treat as regular TLS connection - // Return the connection with the buffered data prepended - logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost) - - // Clear read timeout - if err := conn.SetReadDeadline(time.Time{}); err != nil { - logger.Debug("Failed to clear read deadline: %v", err) - } - - // Create a reader that includes the buffered data + original connection - newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn) - wrappedConn := &proxyProtocolConn{ - Conn: conn, - reader: newReader, - } - return nil, wrappedConn, nil - } - - headerLine := string(buffer[:headerEnd]) - remainingData := buffer[headerEnd+2 : n] - - // Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort" - parts := strings.Fields(headerLine) - if len(parts) != 6 || parts[0] != "PROXY" { - // Check for PROXY UNKNOWN - if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" { - // PROXY UNKNOWN - use original connection info - return nil, conn, nil - } - // Invalid PROXY protocol, but might be regular TLS - treat as such - logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine) - - // Clear read timeout - if err := conn.SetReadDeadline(time.Time{}); err != nil { - logger.Debug("Failed to clear read deadline: %v", err) - } - - // Return the connection with all buffered data prepended - newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn) - wrappedConn := &proxyProtocolConn{ - Conn: conn, - reader: newReader, - } - return nil, wrappedConn, nil - } - - protocol := parts[1] - srcIP := parts[2] - destIP := parts[3] - srcPort, err := strconv.Atoi(parts[4]) - if err != nil { - return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4]) - } - destPort, err := strconv.Atoi(parts[5]) - if err != nil { - return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5]) - } - - // Create a new reader that includes remaining data + original connection - var newReader io.Reader - if len(remainingData) > 0 { - newReader = io.MultiReader(bytes.NewReader(remainingData), conn) - } else { - newReader = conn - } - - // Create a wrapper connection that reads from the combined reader - wrappedConn := &proxyProtocolConn{ - Conn: conn, - reader: newReader, - } - - proxyInfo := &ProxyProtocolInfo{ - Protocol: protocol, - SrcIP: srcIP, - DestIP: destIP, - SrcPort: srcPort, - DestPort: destPort, - OriginalConn: wrappedConn, - } - - // Clear read timeout - if err := conn.SetReadDeadline(time.Time{}); err != nil { - return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err) - } - - return proxyInfo, wrappedConn, nil -} - -// proxyProtocolConn wraps a connection to read from a custom reader -type proxyProtocolConn struct { - net.Conn - reader io.Reader -} - -func (c *proxyProtocolConn) Read(b []byte) (int, error) { - return c.reader.Read(b) -} - -// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo -func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string { - targetTCP, ok := targetAddr.(*net.TCPAddr) - if !ok { - // Fallback for unknown address types - return "PROXY UNKNOWN\r\n" - } - - // Use the original client information from the PROXY protocol - var targetIP string - var protocol string - - // Parse source IP to determine protocol family - srcIP := net.ParseIP(proxyInfo.SrcIP) - if srcIP == nil { - return "PROXY UNKNOWN\r\n" - } - - if srcIP.To4() != nil { - // Source is IPv4, use TCP4 protocol - protocol = "TCP4" - if targetTCP.IP.To4() != nil { - // Target is also IPv4, use as-is - targetIP = targetTCP.IP.String() - } else { - // Target is IPv6, but we need IPv4 for consistent protocol family - if targetTCP.IP.IsLoopback() { - targetIP = "127.0.0.1" - } else { - targetIP = "127.0.0.1" // Safe fallback - } - } - } else { - // Source is IPv6, use TCP6 protocol - protocol = "TCP6" - if targetTCP.IP.To4() != nil { - // Target is IPv4, convert to IPv6 representation - targetIP = "::ffff:" + targetTCP.IP.String() - } else { - // Target is also IPv6, use as-is - targetIP = targetTCP.IP.String() - } - } - - return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", - protocol, - proxyInfo.SrcIP, - targetIP, - proxyInfo.SrcPort, - targetTCP.Port) -} - -// buildProxyProtocolHeader creates a PROXY protocol v1 header -func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string { - clientTCP, ok := clientAddr.(*net.TCPAddr) - if !ok { - // Fallback for unknown address types - return "PROXY UNKNOWN\r\n" - } - - targetTCP, ok := targetAddr.(*net.TCPAddr) - if !ok { - // Fallback for unknown address types - return "PROXY UNKNOWN\r\n" - } - - // Determine protocol family based on client IP and normalize target IP accordingly - var protocol string - var targetIP string - - if clientTCP.IP.To4() != nil { - // Client is IPv4, use TCP4 protocol - protocol = "TCP4" - if targetTCP.IP.To4() != nil { - // Target is also IPv4, use as-is - targetIP = targetTCP.IP.String() - } else { - // Target is IPv6, but we need IPv4 for consistent protocol family - // Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1 - if targetTCP.IP.IsLoopback() { - targetIP = "127.0.0.1" - } else { - // For non-loopback IPv6 targets, we could try to extract embedded IPv4 - // or fall back to a sensible IPv4 address based on the target - targetIP = "127.0.0.1" // Safe fallback - } - } - } else { - // Client is IPv6, use TCP6 protocol - protocol = "TCP6" - if targetTCP.IP.To4() != nil { - // Target is IPv4, convert to IPv6 representation - targetIP = "::ffff:" + targetTCP.IP.String() - } else { - // Target is also IPv6, use as-is - targetIP = targetTCP.IP.String() - } - } - - return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", - protocol, - clientTCP.IP.String(), - targetIP, - clientTCP.Port, - targetTCP.Port) -} - // NewSNIProxy creates a new SNI proxy instance func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) @@ -490,12 +237,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { logger.Debug("Accepted connection from %s", clientConn.RemoteAddr()) // Check for PROXY protocol from trusted upstream - var proxyInfo *ProxyProtocolInfo + var proxyInfo *proxyproto.Info var actualClientConn net.Conn = clientConn if len(p.trustedUpstreams) > 0 { var err error - proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn) + proxyInfo, actualClientConn, err = proxyproto.ParseV1Header(clientConn, p.trustedUpstreams) if err != nil { logger.Debug("Failed to parse PROXY protocol: %v", err) return @@ -575,10 +322,10 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { var proxyHeader string if proxyInfo != nil { // Use original client info from PROXY protocol - proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr()) + proxyHeader = proxyproto.BuildV1HeaderFromInfo(proxyInfo, targetConn.LocalAddr()) } else { // Use direct client connection info - proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr()) + proxyHeader = proxyproto.BuildV1Header(clientConn.RemoteAddr(), targetConn.LocalAddr()) } logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader)) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 747c81d..7f50def 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -3,6 +3,8 @@ package proxy import ( "net" "testing" + + "github.com/fosrl/gerbil/proxyproto" ) func TestBuildProxyProtocolHeader(t *testing.T) { @@ -56,7 +58,7 @@ func TestBuildProxyProtocolHeader(t *testing.T) { t.Fatalf("Failed to resolve target address: %v", err) } - result := buildProxyProtocolHeader(clientTCP, targetTCP) + result := proxyproto.BuildV1Header(clientTCP, targetTCP) if result != tt.expected { t.Errorf("Expected %q, got %q", tt.expected, result) } @@ -69,7 +71,7 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) { clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345} targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443} - result := buildProxyProtocolHeader(clientAddr, targetAddr) + result := proxyproto.BuildV1Header(clientAddr, targetAddr) expected := "PROXY UNKNOWN\r\n" if result != expected { @@ -78,13 +80,8 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) { } func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) { - proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil) - if err != nil { - t.Fatalf("Failed to create SNI proxy: %v", err) - } - // Test IPv4 case - proxyInfo := &ProxyProtocolInfo{ + info := &proxyproto.Info{ Protocol: "TCP4", SrcIP: "10.0.0.1", DestIP: "192.168.1.100", @@ -93,7 +90,7 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) { } targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") - header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr) + header := proxyproto.BuildV1HeaderFromInfo(info, targetAddr) expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n" if header != expected { @@ -101,7 +98,7 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) { } // Test IPv6 case - proxyInfo = &ProxyProtocolInfo{ + info = &proxyproto.Info{ Protocol: "TCP6", SrcIP: "2001:db8::1", DestIP: "2001:db8::2", @@ -110,10 +107,99 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) { } targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080") - header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr) + header = proxyproto.BuildV1HeaderFromInfo(info, targetAddr) expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n" if header != expected { t.Errorf("Expected header '%s', got '%s'", expected, header) } } + +func TestParseV2UDPHeader(t *testing.T) { + // Build a minimal PROXY v2 header for IPv4 UDP + // Magic (12) + ver/cmd (1) + fam/proto (1) + len (2) + src IP (4) + dst IP (4) + src port (2) + dst port (2) = 28 bytes + header := []byte{ + // Magic signature + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, + // Version 2 (0x2x), PROXY command (0x01) + 0x21, + // AF_INET (0x1x), DGRAM/UDP (0x02) + 0x12, + // Address block length: 12 bytes (4+4+2+2) + 0x00, 0x0C, + // Source IP: 192.168.1.100 + 192, 168, 1, 100, + // Destination IP: 10.0.0.1 + 10, 0, 0, 1, + // Source port: 4500 + 0x11, 0x94, + // Destination port: 21820 + 0x55, 0x3C, + } + + // Append a fake application payload + payload := []byte{0x01, 0x02, 0x03} + data := append(header, payload...) + + info, remaining, ok := proxyproto.ParseV2UDPHeader(data) + if !ok { + t.Fatal("Expected ParseV2UDPHeader to return ok=true") + } + if info == nil { + t.Fatal("Expected non-nil Info") + } + if info.Protocol != "UDP4" { + t.Errorf("Expected protocol UDP4, got %s", info.Protocol) + } + if info.SrcIP != "192.168.1.100" { + t.Errorf("Expected SrcIP 192.168.1.100, got %s", info.SrcIP) + } + if info.DestIP != "10.0.0.1" { + t.Errorf("Expected DestIP 10.0.0.1, got %s", info.DestIP) + } + if info.SrcPort != 4500 { + t.Errorf("Expected SrcPort 4500, got %d", info.SrcPort) + } + if info.DestPort != 21820 { + t.Errorf("Expected DestPort 21820, got %d", info.DestPort) + } + if len(remaining) != len(payload) { + t.Errorf("Expected %d remaining bytes, got %d", len(payload), len(remaining)) + } +} + +func TestParseV2UDPHeaderNoHeader(t *testing.T) { + // Data that does NOT start with v2 magic should be returned as-is + data := []byte{0x01, 0x02, 0x03} + info, remaining, ok := proxyproto.ParseV2UDPHeader(data) + if ok { + t.Error("Expected ok=false for non-v2 data") + } + if info != nil { + t.Error("Expected nil Info for non-v2 data") + } + if len(remaining) != len(data) { + t.Errorf("Expected remaining to equal original data length %d, got %d", len(data), len(remaining)) + } +} + +func TestIsV2Header(t *testing.T) { + valid := []byte{ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, + // extra bytes beyond the magic + 0x21, 0x12, + } + if !proxyproto.IsV2Header(valid) { + t.Error("Expected IsV2Header=true for valid magic") + } + + invalid := []byte{0x01, 0x02, 0x03} + if proxyproto.IsV2Header(invalid) { + t.Error("Expected IsV2Header=false for non-magic data") + } + + tooShort := []byte{0x0D, 0x0A} + if proxyproto.IsV2Header(tooShort) { + t.Error("Expected IsV2Header=false for too-short data") + } +} \ No newline at end of file diff --git a/proxyproto/proxyproto.go b/proxyproto/proxyproto.go new file mode 100644 index 0000000..cc13ce2 --- /dev/null +++ b/proxyproto/proxyproto.go @@ -0,0 +1,370 @@ +// Package proxyproto provides shared PROXY protocol v1 (TCP) and v2 (UDP) parsing +// and header building utilities used by both the SNI proxy and UDP relay components. +package proxyproto + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/fosrl/gerbil/logger" +) + +// v2Signature is the 12-byte magic prefix for PROXY protocol v2 headers. +var v2Signature = []byte{ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, +} + +// Info holds information parsed from an incoming PROXY protocol header (v1 or v2). +type Info struct { + Protocol string // e.g. "TCP4", "TCP6", "UDP4", "UDP6" + SrcIP string + DestIP string + SrcPort int + DestPort int +} + +// Conn wraps a net.Conn so that reads are satisfied from a pre-pended buffered +// reader first (remaining bytes after PROXY header parsing) and then from the +// underlying connection. All other net.Conn methods are forwarded unchanged. +type Conn struct { + net.Conn + Reader io.Reader +} + +// Read satisfies net.Conn, draining the buffered reader before falling through +// to the underlying connection. +func (c *Conn) Read(b []byte) (int, error) { + return c.Reader.Read(b) +} + +// IsV2Header returns true when data begins with the 12-byte PROXY protocol v2 +// magic signature. +func IsV2Header(data []byte) bool { + if len(data) < 12 { + return false + } + return bytes.Equal(data[:12], v2Signature) +} + +// ParseV2UDPHeader tries to parse a PROXY protocol v2 header from the front of +// a UDP datagram payload. +// +// Three return values are provided: +// - *Info – filled when a PROXY command header was parsed successfully; nil +// for a LOCAL command or unrecognised address family. +// - []byte – the remaining payload that follows the header (the actual +// application data). +// - bool – true when a v2 header was detected (and consumed), false when +// no v2 magic is present and data should be treated as-is. +func ParseV2UDPHeader(data []byte) (*Info, []byte, bool) { + if !IsV2Header(data) { + return nil, data, false + } + + // Minimum fixed header size: 12 (magic) + 1 (ver/cmd) + 1 (fam/proto) + 2 (len) = 16 + if len(data) < 16 { + return nil, data, false + } + + // Byte 12: version (high nibble) + command (low nibble) + versionCmd := data[12] + version := (versionCmd >> 4) & 0x0F + command := versionCmd & 0x0F + + if version != 2 { + return nil, data, false + } + + // Byte 13: address family (high nibble) + transport protocol (low nibble) + familyProto := data[13] + family := (familyProto >> 4) & 0x0F + protocol := familyProto & 0x0F + + // Bytes 14-15: length of the address block that follows, big-endian + addrLen := int(binary.BigEndian.Uint16(data[14:16])) + totalHeaderLen := 16 + addrLen + + if len(data) < totalHeaderLen { + // Truncated packet – signal that a header was detected but is malformed + return nil, data, false + } + + payload := data[totalHeaderLen:] + + // LOCAL command (0) carries no address information. + if command == 0 { + return nil, payload, true + } + + if command != 1 { + // Unknown command – consume the header and return no info + return nil, payload, true + } + + addrBlock := data[16:totalHeaderLen] + + var ( + srcIP, destIP net.IP + srcPort uint16 + destPort uint16 + protocolStr string + ) + + switch { + case family == 1 && protocol == 1: // AF_INET / STREAM (TCP over IPv4) + if len(addrBlock) < 12 { + return nil, payload, false + } + srcIP = net.IP(addrBlock[0:4]) + destIP = net.IP(addrBlock[4:8]) + srcPort = binary.BigEndian.Uint16(addrBlock[8:10]) + destPort = binary.BigEndian.Uint16(addrBlock[10:12]) + protocolStr = "TCP4" + + case family == 1 && protocol == 2: // AF_INET / DGRAM (UDP over IPv4) + if len(addrBlock) < 12 { + return nil, payload, false + } + srcIP = net.IP(addrBlock[0:4]) + destIP = net.IP(addrBlock[4:8]) + srcPort = binary.BigEndian.Uint16(addrBlock[8:10]) + destPort = binary.BigEndian.Uint16(addrBlock[10:12]) + protocolStr = "UDP4" + + case family == 2 && protocol == 1: // AF_INET6 / STREAM (TCP over IPv6) + if len(addrBlock) < 36 { + return nil, payload, false + } + srcIP = net.IP(addrBlock[0:16]) + destIP = net.IP(addrBlock[16:32]) + srcPort = binary.BigEndian.Uint16(addrBlock[32:34]) + destPort = binary.BigEndian.Uint16(addrBlock[34:36]) + protocolStr = "TCP6" + + case family == 2 && protocol == 2: // AF_INET6 / DGRAM (UDP over IPv6) + if len(addrBlock) < 36 { + return nil, payload, false + } + srcIP = net.IP(addrBlock[0:16]) + destIP = net.IP(addrBlock[16:32]) + srcPort = binary.BigEndian.Uint16(addrBlock[32:34]) + destPort = binary.BigEndian.Uint16(addrBlock[34:36]) + protocolStr = "UDP6" + + default: + // UNSPEC or AF_UNIX – consume the header, no address info available + return nil, payload, true + } + + info := &Info{ + Protocol: protocolStr, + SrcIP: srcIP.String(), + DestIP: destIP.String(), + SrcPort: int(srcPort), + DestPort: int(destPort), + } + return info, payload, true +} + +// ParseV1Header attempts to parse a PROXY protocol v1 (text) header from the +// given TCP connection. +// +// The function first checks whether the remote address appears in +// trustedUpstreams. If it does not, it returns (nil, conn, nil) and the caller +// should treat the connection as a plain (non-proxied) connection. +// +// When a trusted upstream is detected the function reads up to 512 bytes, +// locates the CRLF-terminated header line, and parses the proxy information. +// Whatever bytes were consumed (including any data beyond the header line) are +// re-prepended via a *Conn wrapper so that subsequent reads by the caller are +// transparent. +// +// Return values: +// - *Info – non-nil when a valid PROXY header was parsed. +// - net.Conn – always a valid connection (possibly a *Conn wrapper). +// - error – non-nil only on hard failures (e.g. bad port numbers). +func ParseV1Header(conn net.Conn, trustedUpstreams map[string]struct{}) (*Info, net.Conn, error) { + remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + return nil, conn, fmt.Errorf("failed to parse remote address: %w", err) + } + + if _, isTrusted := trustedUpstreams[remoteHost]; !isTrusted { + return nil, conn, nil + } + + // Give the upstream 5 s to deliver the PROXY header before timing out. + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + return nil, conn, fmt.Errorf("failed to set read deadline: %w", err) + } + + // The PROXY v1 spec mandates the header fits in 108 bytes; 512 is generous. + buffer := make([]byte, 512) + n, err := conn.Read(buffer) + if err != nil { + logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err) + if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil { + logger.Debug("Failed to clear read deadline: %v", clearErr) + } + return nil, conn, nil + } + + // Locate the CRLF that terminates the PROXY header line. + headerEnd := bytes.Index(buffer[:n], []byte("\r\n")) + if headerEnd == -1 { + logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost) + if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil { + logger.Debug("Failed to clear read deadline: %v", clearErr) + } + newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn) + return nil, &Conn{Conn: conn, Reader: newReader}, nil + } + + headerLine := string(buffer[:headerEnd]) + remainingData := buffer[headerEnd+2 : n] + + parts := strings.Fields(headerLine) + + // Handle "PROXY UNKNOWN" – upstream knows the real source but we don't need it. + if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" { + if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil { + logger.Debug("Failed to clear read deadline: %v", clearErr) + } + var newConn net.Conn + if len(remainingData) > 0 { + newConn = &Conn{Conn: conn, Reader: io.MultiReader(bytes.NewReader(remainingData), conn)} + } else { + newConn = conn + } + return nil, newConn, nil + } + + if len(parts) != 6 || parts[0] != "PROXY" { + // Malformed line from a trusted upstream – re-prepend everything and + // let the caller deal with it as a plain TLS connection. + logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine) + if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil { + logger.Debug("Failed to clear read deadline: %v", clearErr) + } + newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn) + return nil, &Conn{Conn: conn, Reader: newReader}, nil + } + + protocol := parts[1] + srcIP := parts[2] + destIP := parts[3] + + srcPort, err := strconv.Atoi(parts[4]) + if err != nil { + return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4]) + } + destPort, err := strconv.Atoi(parts[5]) + if err != nil { + return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5]) + } + + // Re-assemble a reader that returns any bytes read beyond the header first. + var newReader io.Reader + if len(remainingData) > 0 { + newReader = io.MultiReader(bytes.NewReader(remainingData), conn) + } else { + newReader = conn + } + wrappedConn := &Conn{Conn: conn, Reader: newReader} + + if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil { + return nil, conn, fmt.Errorf("failed to clear read deadline: %w", clearErr) + } + + info := &Info{ + Protocol: protocol, + SrcIP: srcIP, + DestIP: destIP, + SrcPort: srcPort, + DestPort: destPort, + } + return info, wrappedConn, nil +} + +// BuildV1Header constructs a PROXY protocol v1 header string from two TCP +// addresses, normalising the protocol family so that v1's constraint of a +// single family per header is satisfied. +func BuildV1Header(clientAddr, targetAddr net.Addr) string { + clientTCP, ok := clientAddr.(*net.TCPAddr) + if !ok { + return "PROXY UNKNOWN\r\n" + } + targetTCP, ok := targetAddr.(*net.TCPAddr) + if !ok { + return "PROXY UNKNOWN\r\n" + } + + var protocol, targetIP string + + if clientTCP.IP.To4() != nil { + // IPv4 client + protocol = "TCP4" + if targetTCP.IP.To4() != nil { + targetIP = targetTCP.IP.String() + } else if targetTCP.IP.IsLoopback() { + targetIP = "127.0.0.1" + } else { + targetIP = "127.0.0.1" // safe fallback for mixed-family + } + } else { + // IPv6 client + protocol = "TCP6" + if targetTCP.IP.To4() != nil { + targetIP = "::ffff:" + targetTCP.IP.String() + } else { + targetIP = targetTCP.IP.String() + } + } + + return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", + protocol, clientTCP.IP.String(), targetIP, clientTCP.Port, targetTCP.Port) +} + +// BuildV1HeaderFromInfo constructs a PROXY protocol v1 header string using a +// previously-parsed *Info (i.e. when this server itself sits behind an +// upstream proxy) and the target TCP address. +func BuildV1HeaderFromInfo(info *Info, targetAddr net.Addr) string { + targetTCP, ok := targetAddr.(*net.TCPAddr) + if !ok { + return "PROXY UNKNOWN\r\n" + } + + srcIP := net.ParseIP(info.SrcIP) + if srcIP == nil { + return "PROXY UNKNOWN\r\n" + } + + var protocol, targetIP string + + if srcIP.To4() != nil { + protocol = "TCP4" + if targetTCP.IP.To4() != nil { + targetIP = targetTCP.IP.String() + } else if targetTCP.IP.IsLoopback() { + targetIP = "127.0.0.1" + } else { + targetIP = "127.0.0.1" + } + } else { + protocol = "TCP6" + if targetTCP.IP.To4() != nil { + targetIP = "::ffff:" + targetTCP.IP.String() + } else { + targetIP = targetTCP.IP.String() + } + } + + return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", + protocol, info.SrcIP, targetIP, info.SrcPort, targetTCP.Port) +} \ No newline at end of file diff --git a/relay/relay.go b/relay/relay.go index 0ab5930..f9edfdd 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -10,10 +10,12 @@ import ( "net" "net/http" "runtime" + "strings" "sync" "time" "github.com/fosrl/gerbil/logger" + "github.com/fosrl/gerbil/proxyproto" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -168,19 +170,50 @@ type UDPProxyServer struct { addrCache sync.Map // ReachableAt is the URL where this server can be reached ReachableAt string + + // proxyProtocol enables PROXY protocol v2 header parsing for incoming UDP packets. + // When enabled, packets from trustedUpstreams that carry a v2 header will have + // their source address overridden with the address reported in the header. + proxyProtocol bool + trustedUpstreams map[string]struct{} } // NewUDPProxyServer initializes the server with a buffered packet channel and derived context. -func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { +// +// proxyProtocol enables PROXY protocol v2 parsing for datagrams arriving from +// any address listed in trustedUpstreams (plain IPs or resolvable hostnames). +// When a trusted datagram carries a v2 header its source address is replaced +// with the address carried inside the header before further processing, so that +// hole-punch endpoints reflect the original client IP rather than the load +// balancer's address. +func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string, proxyProtocol bool, trustedUpstreams []string) *UDPProxyServer { ctx, cancel := context.WithCancel(parentCtx) + + trustedMap := make(map[string]struct{}) + for _, upstream := range trustedUpstreams { + upstream = strings.TrimSpace(upstream) + if upstream == "" { + continue + } + trustedMap[upstream] = struct{}{} + // Also resolve any hostnames to their current IPs so we can match by IP. + if ips, err := net.LookupIP(upstream); err == nil { + for _, ip := range ips { + trustedMap[ip.String()] = struct{}{} + } + } + } + return &UDPProxyServer{ - addr: addr, - serverURL: serverURL, - privateKey: privateKey, - packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput - ReachableAt: reachableAt, - ctx: ctx, - cancel: cancel, + addr: addr, + serverURL: serverURL, + privateKey: privateKey, + packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput + ReachableAt: reachableAt, + ctx: ctx, + cancel: cancel, + proxyProtocol: proxyProtocol, + trustedUpstreams: trustedMap, } } @@ -288,13 +321,48 @@ func (s *UDPProxyServer) readPackets() { // packetWorker processes incoming packets from the channel. func (s *UDPProxyServer) packetWorker() { for packet := range s.packetChan { - // Determine packet type by inspecting the first byte. - if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 { + // effectiveData and effectiveAddr represent the application-layer payload + // and the true originating address. They start as the raw UDP values and + // may be updated below when a PROXY protocol v2 header is present. + effectiveData := packet.data[:packet.n] + effectiveAddr := packet.remoteAddr + + // ---------- PROXY protocol v2 (UDP) ------------------------------------ + // If proxy protocol is enabled and this datagram arrives from a trusted + // upstream (e.g. a load balancer), attempt to parse the v2 header so + // that we use the original client address for hole-punch registration and + // WireGuard session tracking rather than the load balancer's address. + if s.proxyProtocol && len(s.trustedUpstreams) > 0 { + remoteHost := packet.remoteAddr.IP.String() + if _, trusted := s.trustedUpstreams[remoteHost]; trusted { + if info, payload, ok := proxyproto.ParseV2UDPHeader(effectiveData); ok { + if info != nil { + // Override source address with what the proxy reported. + if srcIP := net.ParseIP(info.SrcIP); srcIP != nil { + effectiveAddr = &net.UDPAddr{ + IP: srcIP, + Port: info.SrcPort, + } + logger.Debug("PROXY protocol v2: overriding source %s → %s:%d", + packet.remoteAddr, info.SrcIP, info.SrcPort) + } + } + // Always advance past the header so the remainder is treated + // as the real application payload. + effectiveData = payload + } + } + } + // ----------------------------------------------------------------------- + + // Determine packet type by inspecting the first byte of the (possibly + // stripped) application payload. + if len(effectiveData) > 0 && effectiveData[0] >= 1 && effectiveData[0] <= 4 { // Process as a WireGuard packet. - s.handleWireGuardPacket(packet.data, packet.remoteAddr) + s.handleWireGuardPacket(effectiveData, effectiveAddr) } else { // Rate limit: allow at most 2 hole punch messages per IP:Port per second - rateLimitKey := packet.remoteAddr.String() + rateLimitKey := effectiveAddr.String() entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{ windowStart: time.Now(), }) @@ -316,7 +384,7 @@ func (s *UDPProxyServer) packetWorker() { // Process as an encrypted hole punch message var encMsg EncryptedHolePunchMessage - if err := json.Unmarshal(packet.data, &encMsg); err != nil { + if err := json.Unmarshal(effectiveData, &encMsg); err != nil { logger.Error("Error unmarshaling encrypted message: %v", err) // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) @@ -352,14 +420,14 @@ func (s *UDPProxyServer) packetWorker() { NewtID: msg.NewtID, OlmID: msg.OlmID, Token: msg.Token, - IP: packet.remoteAddr.IP.String(), - Port: packet.remoteAddr.Port, + IP: effectiveAddr.IP.String(), + Port: effectiveAddr.Port, Timestamp: time.Now().Unix(), ReachableAt: s.ReachableAt, ExitNodePublicKey: s.privateKey.PublicKey().String(), ClientPublicKey: msg.PublicKey, } - logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port) + logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", effectiveAddr.String(), endpoint.IP, endpoint.Port) s.notifyServer(endpoint) s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment }