mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-15 20:59:56 +00:00
Add pp to relay
This commit is contained in:
370
proxyproto/proxyproto.go
Normal file
370
proxyproto/proxyproto.go
Normal file
@@ -0,0 +1,370 @@
|
||||
// Package proxyproto provides shared PROXY protocol v1 (TCP) and v2 (UDP) parsing
|
||||
// and header building utilities used by both the SNI proxy and UDP relay components.
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
)
|
||||
|
||||
// v2Signature is the 12-byte magic prefix for PROXY protocol v2 headers.
|
||||
var v2Signature = []byte{
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
}
|
||||
|
||||
// Info holds information parsed from an incoming PROXY protocol header (v1 or v2).
|
||||
type Info struct {
|
||||
Protocol string // e.g. "TCP4", "TCP6", "UDP4", "UDP6"
|
||||
SrcIP string
|
||||
DestIP string
|
||||
SrcPort int
|
||||
DestPort int
|
||||
}
|
||||
|
||||
// Conn wraps a net.Conn so that reads are satisfied from a pre-pended buffered
|
||||
// reader first (remaining bytes after PROXY header parsing) and then from the
|
||||
// underlying connection. All other net.Conn methods are forwarded unchanged.
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
Reader io.Reader
|
||||
}
|
||||
|
||||
// Read satisfies net.Conn, draining the buffered reader before falling through
|
||||
// to the underlying connection.
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
return c.Reader.Read(b)
|
||||
}
|
||||
|
||||
// IsV2Header returns true when data begins with the 12-byte PROXY protocol v2
|
||||
// magic signature.
|
||||
func IsV2Header(data []byte) bool {
|
||||
if len(data) < 12 {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(data[:12], v2Signature)
|
||||
}
|
||||
|
||||
// ParseV2UDPHeader tries to parse a PROXY protocol v2 header from the front of
|
||||
// a UDP datagram payload.
|
||||
//
|
||||
// Three return values are provided:
|
||||
// - *Info – filled when a PROXY command header was parsed successfully; nil
|
||||
// for a LOCAL command or unrecognised address family.
|
||||
// - []byte – the remaining payload that follows the header (the actual
|
||||
// application data).
|
||||
// - bool – true when a v2 header was detected (and consumed), false when
|
||||
// no v2 magic is present and data should be treated as-is.
|
||||
func ParseV2UDPHeader(data []byte) (*Info, []byte, bool) {
|
||||
if !IsV2Header(data) {
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
// Minimum fixed header size: 12 (magic) + 1 (ver/cmd) + 1 (fam/proto) + 2 (len) = 16
|
||||
if len(data) < 16 {
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
// Byte 12: version (high nibble) + command (low nibble)
|
||||
versionCmd := data[12]
|
||||
version := (versionCmd >> 4) & 0x0F
|
||||
command := versionCmd & 0x0F
|
||||
|
||||
if version != 2 {
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
// Byte 13: address family (high nibble) + transport protocol (low nibble)
|
||||
familyProto := data[13]
|
||||
family := (familyProto >> 4) & 0x0F
|
||||
protocol := familyProto & 0x0F
|
||||
|
||||
// Bytes 14-15: length of the address block that follows, big-endian
|
||||
addrLen := int(binary.BigEndian.Uint16(data[14:16]))
|
||||
totalHeaderLen := 16 + addrLen
|
||||
|
||||
if len(data) < totalHeaderLen {
|
||||
// Truncated packet – signal that a header was detected but is malformed
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
payload := data[totalHeaderLen:]
|
||||
|
||||
// LOCAL command (0) carries no address information.
|
||||
if command == 0 {
|
||||
return nil, payload, true
|
||||
}
|
||||
|
||||
if command != 1 {
|
||||
// Unknown command – consume the header and return no info
|
||||
return nil, payload, true
|
||||
}
|
||||
|
||||
addrBlock := data[16:totalHeaderLen]
|
||||
|
||||
var (
|
||||
srcIP, destIP net.IP
|
||||
srcPort uint16
|
||||
destPort uint16
|
||||
protocolStr string
|
||||
)
|
||||
|
||||
switch {
|
||||
case family == 1 && protocol == 1: // AF_INET / STREAM (TCP over IPv4)
|
||||
if len(addrBlock) < 12 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:4])
|
||||
destIP = net.IP(addrBlock[4:8])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[8:10])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[10:12])
|
||||
protocolStr = "TCP4"
|
||||
|
||||
case family == 1 && protocol == 2: // AF_INET / DGRAM (UDP over IPv4)
|
||||
if len(addrBlock) < 12 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:4])
|
||||
destIP = net.IP(addrBlock[4:8])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[8:10])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[10:12])
|
||||
protocolStr = "UDP4"
|
||||
|
||||
case family == 2 && protocol == 1: // AF_INET6 / STREAM (TCP over IPv6)
|
||||
if len(addrBlock) < 36 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:16])
|
||||
destIP = net.IP(addrBlock[16:32])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[32:34])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[34:36])
|
||||
protocolStr = "TCP6"
|
||||
|
||||
case family == 2 && protocol == 2: // AF_INET6 / DGRAM (UDP over IPv6)
|
||||
if len(addrBlock) < 36 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:16])
|
||||
destIP = net.IP(addrBlock[16:32])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[32:34])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[34:36])
|
||||
protocolStr = "UDP6"
|
||||
|
||||
default:
|
||||
// UNSPEC or AF_UNIX – consume the header, no address info available
|
||||
return nil, payload, true
|
||||
}
|
||||
|
||||
info := &Info{
|
||||
Protocol: protocolStr,
|
||||
SrcIP: srcIP.String(),
|
||||
DestIP: destIP.String(),
|
||||
SrcPort: int(srcPort),
|
||||
DestPort: int(destPort),
|
||||
}
|
||||
return info, payload, true
|
||||
}
|
||||
|
||||
// ParseV1Header attempts to parse a PROXY protocol v1 (text) header from the
|
||||
// given TCP connection.
|
||||
//
|
||||
// The function first checks whether the remote address appears in
|
||||
// trustedUpstreams. If it does not, it returns (nil, conn, nil) and the caller
|
||||
// should treat the connection as a plain (non-proxied) connection.
|
||||
//
|
||||
// When a trusted upstream is detected the function reads up to 512 bytes,
|
||||
// locates the CRLF-terminated header line, and parses the proxy information.
|
||||
// Whatever bytes were consumed (including any data beyond the header line) are
|
||||
// re-prepended via a *Conn wrapper so that subsequent reads by the caller are
|
||||
// transparent.
|
||||
//
|
||||
// Return values:
|
||||
// - *Info – non-nil when a valid PROXY header was parsed.
|
||||
// - net.Conn – always a valid connection (possibly a *Conn wrapper).
|
||||
// - error – non-nil only on hard failures (e.g. bad port numbers).
|
||||
func ParseV1Header(conn net.Conn, trustedUpstreams map[string]struct{}) (*Info, net.Conn, error) {
|
||||
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
||||
}
|
||||
|
||||
if _, isTrusted := trustedUpstreams[remoteHost]; !isTrusted {
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Give the upstream 5 s to deliver the PROXY header before timing out.
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// The PROXY v1 spec mandates the header fits in 108 bytes; 512 is generous.
|
||||
buffer := make([]byte, 512)
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Locate the CRLF that terminates the PROXY header line.
|
||||
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
||||
if headerEnd == -1 {
|
||||
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
return nil, &Conn{Conn: conn, Reader: newReader}, nil
|
||||
}
|
||||
|
||||
headerLine := string(buffer[:headerEnd])
|
||||
remainingData := buffer[headerEnd+2 : n]
|
||||
|
||||
parts := strings.Fields(headerLine)
|
||||
|
||||
// Handle "PROXY UNKNOWN" – upstream knows the real source but we don't need it.
|
||||
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
var newConn net.Conn
|
||||
if len(remainingData) > 0 {
|
||||
newConn = &Conn{Conn: conn, Reader: io.MultiReader(bytes.NewReader(remainingData), conn)}
|
||||
} else {
|
||||
newConn = conn
|
||||
}
|
||||
return nil, newConn, nil
|
||||
}
|
||||
|
||||
if len(parts) != 6 || parts[0] != "PROXY" {
|
||||
// Malformed line from a trusted upstream – re-prepend everything and
|
||||
// let the caller deal with it as a plain TLS connection.
|
||||
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
return nil, &Conn{Conn: conn, Reader: newReader}, nil
|
||||
}
|
||||
|
||||
protocol := parts[1]
|
||||
srcIP := parts[2]
|
||||
destIP := parts[3]
|
||||
|
||||
srcPort, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
||||
}
|
||||
destPort, err := strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
||||
}
|
||||
|
||||
// Re-assemble a reader that returns any bytes read beyond the header first.
|
||||
var newReader io.Reader
|
||||
if len(remainingData) > 0 {
|
||||
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
||||
} else {
|
||||
newReader = conn
|
||||
}
|
||||
wrappedConn := &Conn{Conn: conn, Reader: newReader}
|
||||
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", clearErr)
|
||||
}
|
||||
|
||||
info := &Info{
|
||||
Protocol: protocol,
|
||||
SrcIP: srcIP,
|
||||
DestIP: destIP,
|
||||
SrcPort: srcPort,
|
||||
DestPort: destPort,
|
||||
}
|
||||
return info, wrappedConn, nil
|
||||
}
|
||||
|
||||
// BuildV1Header constructs a PROXY protocol v1 header string from two TCP
|
||||
// addresses, normalising the protocol family so that v1's constraint of a
|
||||
// single family per header is satisfied.
|
||||
func BuildV1Header(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
var protocol, targetIP string
|
||||
|
||||
if clientTCP.IP.To4() != nil {
|
||||
// IPv4 client
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1" // safe fallback for mixed-family
|
||||
}
|
||||
} else {
|
||||
// IPv6 client
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol, clientTCP.IP.String(), targetIP, clientTCP.Port, targetTCP.Port)
|
||||
}
|
||||
|
||||
// BuildV1HeaderFromInfo constructs a PROXY protocol v1 header string using a
|
||||
// previously-parsed *Info (i.e. when this server itself sits behind an
|
||||
// upstream proxy) and the target TCP address.
|
||||
func BuildV1HeaderFromInfo(info *Info, targetAddr net.Addr) string {
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP(info.SrcIP)
|
||||
if srcIP == nil {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
var protocol, targetIP string
|
||||
|
||||
if srcIP.To4() != nil {
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1"
|
||||
}
|
||||
} else {
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol, info.SrcIP, targetIP, info.SrcPort, targetTCP.Port)
|
||||
}
|
||||
Reference in New Issue
Block a user