mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-12 19:29:55 +00:00
Add pp to relay
This commit is contained in:
32
main.go
32
main.go
@@ -355,21 +355,8 @@ func main() {
|
||||
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
|
||||
})
|
||||
|
||||
// Start the UDP proxy server
|
||||
relayPort := wgconfig.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // in case there is no relay port set, use 21820
|
||||
}
|
||||
proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
}
|
||||
defer proxyRelay.Stop()
|
||||
|
||||
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
|
||||
// SO YOU DON'T NEED TO SET THIS SEPARATELY
|
||||
// Parse local overrides
|
||||
// Parse local overrides and trusted upstreams early so that both the relay
|
||||
// and the SNI proxy share the same configuration values.
|
||||
var localOverrides []string
|
||||
if localOverridesStr != "" {
|
||||
localOverrides = strings.Split(localOverridesStr, ",")
|
||||
@@ -388,6 +375,21 @@ func main() {
|
||||
logger.Info("Trusted upstreams configured: %v", trustedUpstreams)
|
||||
}
|
||||
|
||||
// Start the UDP proxy server.
|
||||
// proxyProtocol and trustedUpstreams are forwarded so the relay can strip
|
||||
// PROXY protocol v2 headers from load-balancer traffic and recover the
|
||||
// original client IP for hole-punch registration.
|
||||
relayPort := wgconfig.RelayPort
|
||||
if relayPort == 0 {
|
||||
relayPort = 21820 // in case there is no relay port set, use 21820
|
||||
}
|
||||
proxyRelay = relay.NewUDPProxyServer(groupCtx, fmt.Sprintf(":%d", relayPort), remoteConfigURL, key, reachableAt, proxyProtocol, trustedUpstreams)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
}
|
||||
defer proxyRelay.Stop()
|
||||
|
||||
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol, trustedUpstreams)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create proxy: %v", err)
|
||||
|
||||
263
proxy/proxy.go
263
proxy/proxy.go
@@ -11,12 +11,12 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/fosrl/gerbil/proxyproto"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
@@ -32,16 +32,6 @@ type RouteAPIResponse struct {
|
||||
Endpoints []string `json:"endpoints"`
|
||||
}
|
||||
|
||||
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
|
||||
type ProxyProtocolInfo struct {
|
||||
Protocol string // TCP4 or TCP6
|
||||
SrcIP string
|
||||
DestIP string
|
||||
SrcPort int
|
||||
DestPort int
|
||||
OriginalConn net.Conn // The original connection after PROXY protocol parsing
|
||||
}
|
||||
|
||||
// SNIProxy represents the main proxy server
|
||||
type SNIProxy struct {
|
||||
port int
|
||||
@@ -89,249 +79,6 @@ func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
|
||||
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
|
||||
// Check if the connection comes from a trusted upstream
|
||||
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
||||
}
|
||||
|
||||
// Resolve the remote IP to hostname to check if it's trusted
|
||||
// For simplicity, we'll check the IP directly in trusted upstreams
|
||||
// In production, you might want to do reverse DNS lookup
|
||||
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
|
||||
// Not from trusted upstream, return original connection
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Set read timeout for PROXY protocol parsing
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Read the first line (PROXY protocol header)
|
||||
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
// If we can't read from trusted upstream, treat as regular connection
|
||||
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
|
||||
// Clear read timeout before returning
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Find the end of the first line (CRLF)
|
||||
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
||||
if headerEnd == -1 {
|
||||
// No PROXY protocol header found, treat as regular TLS connection
|
||||
// Return the connection with the buffered data prepended
|
||||
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
}
|
||||
|
||||
// Create a reader that includes the buffered data + original connection
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
return nil, wrappedConn, nil
|
||||
}
|
||||
|
||||
headerLine := string(buffer[:headerEnd])
|
||||
remainingData := buffer[headerEnd+2 : n]
|
||||
|
||||
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
|
||||
parts := strings.Fields(headerLine)
|
||||
if len(parts) != 6 || parts[0] != "PROXY" {
|
||||
// Check for PROXY UNKNOWN
|
||||
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
||||
// PROXY UNKNOWN - use original connection info
|
||||
return nil, conn, nil
|
||||
}
|
||||
// Invalid PROXY protocol, but might be regular TLS - treat as such
|
||||
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", err)
|
||||
}
|
||||
|
||||
// Return the connection with all buffered data prepended
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
return nil, wrappedConn, nil
|
||||
}
|
||||
|
||||
protocol := parts[1]
|
||||
srcIP := parts[2]
|
||||
destIP := parts[3]
|
||||
srcPort, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
||||
}
|
||||
destPort, err := strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
||||
}
|
||||
|
||||
// Create a new reader that includes remaining data + original connection
|
||||
var newReader io.Reader
|
||||
if len(remainingData) > 0 {
|
||||
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
||||
} else {
|
||||
newReader = conn
|
||||
}
|
||||
|
||||
// Create a wrapper connection that reads from the combined reader
|
||||
wrappedConn := &proxyProtocolConn{
|
||||
Conn: conn,
|
||||
reader: newReader,
|
||||
}
|
||||
|
||||
proxyInfo := &ProxyProtocolInfo{
|
||||
Protocol: protocol,
|
||||
SrcIP: srcIP,
|
||||
DestIP: destIP,
|
||||
SrcPort: srcPort,
|
||||
DestPort: destPort,
|
||||
OriginalConn: wrappedConn,
|
||||
}
|
||||
|
||||
// Clear read timeout
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
return proxyInfo, wrappedConn, nil
|
||||
}
|
||||
|
||||
// proxyProtocolConn wraps a connection to read from a custom reader
|
||||
type proxyProtocolConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
|
||||
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Use the original client information from the PROXY protocol
|
||||
var targetIP string
|
||||
var protocol string
|
||||
|
||||
// Parse source IP to determine protocol family
|
||||
srcIP := net.ParseIP(proxyInfo.SrcIP)
|
||||
if srcIP == nil {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
if srcIP.To4() != nil {
|
||||
// Source is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Source is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
proxyInfo.SrcIP,
|
||||
targetIP,
|
||||
proxyInfo.SrcPort,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// buildProxyProtocolHeader creates a PROXY protocol v1 header
|
||||
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
// Fallback for unknown address types
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
// Determine protocol family based on client IP and normalize target IP accordingly
|
||||
var protocol string
|
||||
var targetIP string
|
||||
|
||||
if clientTCP.IP.To4() != nil {
|
||||
// Client is IPv4, use TCP4 protocol
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is also IPv4, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is IPv6, but we need IPv4 for consistent protocol family
|
||||
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
|
||||
if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
|
||||
// or fall back to a sensible IPv4 address based on the target
|
||||
targetIP = "127.0.0.1" // Safe fallback
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Client is IPv6, use TCP6 protocol
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
// Target is IPv4, convert to IPv6 representation
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
// Target is also IPv6, use as-is
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol,
|
||||
clientTCP.IP.String(),
|
||||
targetIP,
|
||||
clientTCP.Port,
|
||||
targetTCP.Port)
|
||||
}
|
||||
|
||||
// NewSNIProxy creates a new SNI proxy instance
|
||||
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -490,12 +237,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||
|
||||
// Check for PROXY protocol from trusted upstream
|
||||
var proxyInfo *ProxyProtocolInfo
|
||||
var proxyInfo *proxyproto.Info
|
||||
var actualClientConn net.Conn = clientConn
|
||||
|
||||
if len(p.trustedUpstreams) > 0 {
|
||||
var err error
|
||||
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
||||
proxyInfo, actualClientConn, err = proxyproto.ParseV1Header(clientConn, p.trustedUpstreams)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
||||
return
|
||||
@@ -575,10 +322,10 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
var proxyHeader string
|
||||
if proxyInfo != nil {
|
||||
// Use original client info from PROXY protocol
|
||||
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
|
||||
proxyHeader = proxyproto.BuildV1HeaderFromInfo(proxyInfo, targetConn.LocalAddr())
|
||||
} else {
|
||||
// Use direct client connection info
|
||||
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
proxyHeader = proxyproto.BuildV1Header(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
||||
}
|
||||
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package proxy
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/fosrl/gerbil/proxyproto"
|
||||
)
|
||||
|
||||
func TestBuildProxyProtocolHeader(t *testing.T) {
|
||||
@@ -56,7 +58,7 @@ func TestBuildProxyProtocolHeader(t *testing.T) {
|
||||
t.Fatalf("Failed to resolve target address: %v", err)
|
||||
}
|
||||
|
||||
result := buildProxyProtocolHeader(clientTCP, targetTCP)
|
||||
result := proxyproto.BuildV1Header(clientTCP, targetTCP)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
@@ -69,7 +71,7 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
||||
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
|
||||
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
|
||||
|
||||
result := buildProxyProtocolHeader(clientAddr, targetAddr)
|
||||
result := proxyproto.BuildV1Header(clientAddr, targetAddr)
|
||||
expected := "PROXY UNKNOWN\r\n"
|
||||
|
||||
if result != expected {
|
||||
@@ -78,13 +80,8 @@ func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||
proxy, err := NewSNIProxy(8443, "", "", "127.0.0.1", 443, nil, true, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SNI proxy: %v", err)
|
||||
}
|
||||
|
||||
// Test IPv4 case
|
||||
proxyInfo := &ProxyProtocolInfo{
|
||||
info := &proxyproto.Info{
|
||||
Protocol: "TCP4",
|
||||
SrcIP: "10.0.0.1",
|
||||
DestIP: "192.168.1.100",
|
||||
@@ -93,7 +90,7 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
targetAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
|
||||
header := proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||
header := proxyproto.BuildV1HeaderFromInfo(info, targetAddr)
|
||||
|
||||
expected := "PROXY TCP4 10.0.0.1 127.0.0.1 12345 8080\r\n"
|
||||
if header != expected {
|
||||
@@ -101,7 +98,7 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test IPv6 case
|
||||
proxyInfo = &ProxyProtocolInfo{
|
||||
info = &proxyproto.Info{
|
||||
Protocol: "TCP6",
|
||||
SrcIP: "2001:db8::1",
|
||||
DestIP: "2001:db8::2",
|
||||
@@ -110,10 +107,99 @@ func TestBuildProxyProtocolHeaderFromInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
targetAddr, _ = net.ResolveTCPAddr("tcp6", "[::1]:8080")
|
||||
header = proxy.buildProxyProtocolHeaderFromInfo(proxyInfo, targetAddr)
|
||||
header = proxyproto.BuildV1HeaderFromInfo(info, targetAddr)
|
||||
|
||||
expected = "PROXY TCP6 2001:db8::1 ::1 12345 8080\r\n"
|
||||
if header != expected {
|
||||
t.Errorf("Expected header '%s', got '%s'", expected, header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseV2UDPHeader(t *testing.T) {
|
||||
// Build a minimal PROXY v2 header for IPv4 UDP
|
||||
// Magic (12) + ver/cmd (1) + fam/proto (1) + len (2) + src IP (4) + dst IP (4) + src port (2) + dst port (2) = 28 bytes
|
||||
header := []byte{
|
||||
// Magic signature
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
// Version 2 (0x2x), PROXY command (0x01)
|
||||
0x21,
|
||||
// AF_INET (0x1x), DGRAM/UDP (0x02)
|
||||
0x12,
|
||||
// Address block length: 12 bytes (4+4+2+2)
|
||||
0x00, 0x0C,
|
||||
// Source IP: 192.168.1.100
|
||||
192, 168, 1, 100,
|
||||
// Destination IP: 10.0.0.1
|
||||
10, 0, 0, 1,
|
||||
// Source port: 4500
|
||||
0x11, 0x94,
|
||||
// Destination port: 21820
|
||||
0x55, 0x3C,
|
||||
}
|
||||
|
||||
// Append a fake application payload
|
||||
payload := []byte{0x01, 0x02, 0x03}
|
||||
data := append(header, payload...)
|
||||
|
||||
info, remaining, ok := proxyproto.ParseV2UDPHeader(data)
|
||||
if !ok {
|
||||
t.Fatal("Expected ParseV2UDPHeader to return ok=true")
|
||||
}
|
||||
if info == nil {
|
||||
t.Fatal("Expected non-nil Info")
|
||||
}
|
||||
if info.Protocol != "UDP4" {
|
||||
t.Errorf("Expected protocol UDP4, got %s", info.Protocol)
|
||||
}
|
||||
if info.SrcIP != "192.168.1.100" {
|
||||
t.Errorf("Expected SrcIP 192.168.1.100, got %s", info.SrcIP)
|
||||
}
|
||||
if info.DestIP != "10.0.0.1" {
|
||||
t.Errorf("Expected DestIP 10.0.0.1, got %s", info.DestIP)
|
||||
}
|
||||
if info.SrcPort != 4500 {
|
||||
t.Errorf("Expected SrcPort 4500, got %d", info.SrcPort)
|
||||
}
|
||||
if info.DestPort != 21820 {
|
||||
t.Errorf("Expected DestPort 21820, got %d", info.DestPort)
|
||||
}
|
||||
if len(remaining) != len(payload) {
|
||||
t.Errorf("Expected %d remaining bytes, got %d", len(payload), len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseV2UDPHeaderNoHeader(t *testing.T) {
|
||||
// Data that does NOT start with v2 magic should be returned as-is
|
||||
data := []byte{0x01, 0x02, 0x03}
|
||||
info, remaining, ok := proxyproto.ParseV2UDPHeader(data)
|
||||
if ok {
|
||||
t.Error("Expected ok=false for non-v2 data")
|
||||
}
|
||||
if info != nil {
|
||||
t.Error("Expected nil Info for non-v2 data")
|
||||
}
|
||||
if len(remaining) != len(data) {
|
||||
t.Errorf("Expected remaining to equal original data length %d, got %d", len(data), len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsV2Header(t *testing.T) {
|
||||
valid := []byte{
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
// extra bytes beyond the magic
|
||||
0x21, 0x12,
|
||||
}
|
||||
if !proxyproto.IsV2Header(valid) {
|
||||
t.Error("Expected IsV2Header=true for valid magic")
|
||||
}
|
||||
|
||||
invalid := []byte{0x01, 0x02, 0x03}
|
||||
if proxyproto.IsV2Header(invalid) {
|
||||
t.Error("Expected IsV2Header=false for non-magic data")
|
||||
}
|
||||
|
||||
tooShort := []byte{0x0D, 0x0A}
|
||||
if proxyproto.IsV2Header(tooShort) {
|
||||
t.Error("Expected IsV2Header=false for too-short data")
|
||||
}
|
||||
}
|
||||
370
proxyproto/proxyproto.go
Normal file
370
proxyproto/proxyproto.go
Normal file
@@ -0,0 +1,370 @@
|
||||
// Package proxyproto provides shared PROXY protocol v1 (TCP) and v2 (UDP) parsing
|
||||
// and header building utilities used by both the SNI proxy and UDP relay components.
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
)
|
||||
|
||||
// v2Signature is the 12-byte magic prefix for PROXY protocol v2 headers.
|
||||
var v2Signature = []byte{
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||
}
|
||||
|
||||
// Info holds information parsed from an incoming PROXY protocol header (v1 or v2).
|
||||
type Info struct {
|
||||
Protocol string // e.g. "TCP4", "TCP6", "UDP4", "UDP6"
|
||||
SrcIP string
|
||||
DestIP string
|
||||
SrcPort int
|
||||
DestPort int
|
||||
}
|
||||
|
||||
// Conn wraps a net.Conn so that reads are satisfied from a pre-pended buffered
|
||||
// reader first (remaining bytes after PROXY header parsing) and then from the
|
||||
// underlying connection. All other net.Conn methods are forwarded unchanged.
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
Reader io.Reader
|
||||
}
|
||||
|
||||
// Read satisfies net.Conn, draining the buffered reader before falling through
|
||||
// to the underlying connection.
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
return c.Reader.Read(b)
|
||||
}
|
||||
|
||||
// IsV2Header returns true when data begins with the 12-byte PROXY protocol v2
|
||||
// magic signature.
|
||||
func IsV2Header(data []byte) bool {
|
||||
if len(data) < 12 {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(data[:12], v2Signature)
|
||||
}
|
||||
|
||||
// ParseV2UDPHeader tries to parse a PROXY protocol v2 header from the front of
|
||||
// a UDP datagram payload.
|
||||
//
|
||||
// Three return values are provided:
|
||||
// - *Info – filled when a PROXY command header was parsed successfully; nil
|
||||
// for a LOCAL command or unrecognised address family.
|
||||
// - []byte – the remaining payload that follows the header (the actual
|
||||
// application data).
|
||||
// - bool – true when a v2 header was detected (and consumed), false when
|
||||
// no v2 magic is present and data should be treated as-is.
|
||||
func ParseV2UDPHeader(data []byte) (*Info, []byte, bool) {
|
||||
if !IsV2Header(data) {
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
// Minimum fixed header size: 12 (magic) + 1 (ver/cmd) + 1 (fam/proto) + 2 (len) = 16
|
||||
if len(data) < 16 {
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
// Byte 12: version (high nibble) + command (low nibble)
|
||||
versionCmd := data[12]
|
||||
version := (versionCmd >> 4) & 0x0F
|
||||
command := versionCmd & 0x0F
|
||||
|
||||
if version != 2 {
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
// Byte 13: address family (high nibble) + transport protocol (low nibble)
|
||||
familyProto := data[13]
|
||||
family := (familyProto >> 4) & 0x0F
|
||||
protocol := familyProto & 0x0F
|
||||
|
||||
// Bytes 14-15: length of the address block that follows, big-endian
|
||||
addrLen := int(binary.BigEndian.Uint16(data[14:16]))
|
||||
totalHeaderLen := 16 + addrLen
|
||||
|
||||
if len(data) < totalHeaderLen {
|
||||
// Truncated packet – signal that a header was detected but is malformed
|
||||
return nil, data, false
|
||||
}
|
||||
|
||||
payload := data[totalHeaderLen:]
|
||||
|
||||
// LOCAL command (0) carries no address information.
|
||||
if command == 0 {
|
||||
return nil, payload, true
|
||||
}
|
||||
|
||||
if command != 1 {
|
||||
// Unknown command – consume the header and return no info
|
||||
return nil, payload, true
|
||||
}
|
||||
|
||||
addrBlock := data[16:totalHeaderLen]
|
||||
|
||||
var (
|
||||
srcIP, destIP net.IP
|
||||
srcPort uint16
|
||||
destPort uint16
|
||||
protocolStr string
|
||||
)
|
||||
|
||||
switch {
|
||||
case family == 1 && protocol == 1: // AF_INET / STREAM (TCP over IPv4)
|
||||
if len(addrBlock) < 12 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:4])
|
||||
destIP = net.IP(addrBlock[4:8])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[8:10])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[10:12])
|
||||
protocolStr = "TCP4"
|
||||
|
||||
case family == 1 && protocol == 2: // AF_INET / DGRAM (UDP over IPv4)
|
||||
if len(addrBlock) < 12 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:4])
|
||||
destIP = net.IP(addrBlock[4:8])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[8:10])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[10:12])
|
||||
protocolStr = "UDP4"
|
||||
|
||||
case family == 2 && protocol == 1: // AF_INET6 / STREAM (TCP over IPv6)
|
||||
if len(addrBlock) < 36 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:16])
|
||||
destIP = net.IP(addrBlock[16:32])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[32:34])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[34:36])
|
||||
protocolStr = "TCP6"
|
||||
|
||||
case family == 2 && protocol == 2: // AF_INET6 / DGRAM (UDP over IPv6)
|
||||
if len(addrBlock) < 36 {
|
||||
return nil, payload, false
|
||||
}
|
||||
srcIP = net.IP(addrBlock[0:16])
|
||||
destIP = net.IP(addrBlock[16:32])
|
||||
srcPort = binary.BigEndian.Uint16(addrBlock[32:34])
|
||||
destPort = binary.BigEndian.Uint16(addrBlock[34:36])
|
||||
protocolStr = "UDP6"
|
||||
|
||||
default:
|
||||
// UNSPEC or AF_UNIX – consume the header, no address info available
|
||||
return nil, payload, true
|
||||
}
|
||||
|
||||
info := &Info{
|
||||
Protocol: protocolStr,
|
||||
SrcIP: srcIP.String(),
|
||||
DestIP: destIP.String(),
|
||||
SrcPort: int(srcPort),
|
||||
DestPort: int(destPort),
|
||||
}
|
||||
return info, payload, true
|
||||
}
|
||||
|
||||
// ParseV1Header attempts to parse a PROXY protocol v1 (text) header from the
|
||||
// given TCP connection.
|
||||
//
|
||||
// The function first checks whether the remote address appears in
|
||||
// trustedUpstreams. If it does not, it returns (nil, conn, nil) and the caller
|
||||
// should treat the connection as a plain (non-proxied) connection.
|
||||
//
|
||||
// When a trusted upstream is detected the function reads up to 512 bytes,
|
||||
// locates the CRLF-terminated header line, and parses the proxy information.
|
||||
// Whatever bytes were consumed (including any data beyond the header line) are
|
||||
// re-prepended via a *Conn wrapper so that subsequent reads by the caller are
|
||||
// transparent.
|
||||
//
|
||||
// Return values:
|
||||
// - *Info – non-nil when a valid PROXY header was parsed.
|
||||
// - net.Conn – always a valid connection (possibly a *Conn wrapper).
|
||||
// - error – non-nil only on hard failures (e.g. bad port numbers).
|
||||
func ParseV1Header(conn net.Conn, trustedUpstreams map[string]struct{}) (*Info, net.Conn, error) {
|
||||
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
||||
}
|
||||
|
||||
if _, isTrusted := trustedUpstreams[remoteHost]; !isTrusted {
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Give the upstream 5 s to deliver the PROXY header before timing out.
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// The PROXY v1 spec mandates the header fits in 108 bytes; 512 is generous.
|
||||
buffer := make([]byte, 512)
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
return nil, conn, nil
|
||||
}
|
||||
|
||||
// Locate the CRLF that terminates the PROXY header line.
|
||||
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
||||
if headerEnd == -1 {
|
||||
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
return nil, &Conn{Conn: conn, Reader: newReader}, nil
|
||||
}
|
||||
|
||||
headerLine := string(buffer[:headerEnd])
|
||||
remainingData := buffer[headerEnd+2 : n]
|
||||
|
||||
parts := strings.Fields(headerLine)
|
||||
|
||||
// Handle "PROXY UNKNOWN" – upstream knows the real source but we don't need it.
|
||||
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
var newConn net.Conn
|
||||
if len(remainingData) > 0 {
|
||||
newConn = &Conn{Conn: conn, Reader: io.MultiReader(bytes.NewReader(remainingData), conn)}
|
||||
} else {
|
||||
newConn = conn
|
||||
}
|
||||
return nil, newConn, nil
|
||||
}
|
||||
|
||||
if len(parts) != 6 || parts[0] != "PROXY" {
|
||||
// Malformed line from a trusted upstream – re-prepend everything and
|
||||
// let the caller deal with it as a plain TLS connection.
|
||||
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
||||
}
|
||||
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
||||
return nil, &Conn{Conn: conn, Reader: newReader}, nil
|
||||
}
|
||||
|
||||
protocol := parts[1]
|
||||
srcIP := parts[2]
|
||||
destIP := parts[3]
|
||||
|
||||
srcPort, err := strconv.Atoi(parts[4])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
||||
}
|
||||
destPort, err := strconv.Atoi(parts[5])
|
||||
if err != nil {
|
||||
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
||||
}
|
||||
|
||||
// Re-assemble a reader that returns any bytes read beyond the header first.
|
||||
var newReader io.Reader
|
||||
if len(remainingData) > 0 {
|
||||
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
||||
} else {
|
||||
newReader = conn
|
||||
}
|
||||
wrappedConn := &Conn{Conn: conn, Reader: newReader}
|
||||
|
||||
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
||||
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", clearErr)
|
||||
}
|
||||
|
||||
info := &Info{
|
||||
Protocol: protocol,
|
||||
SrcIP: srcIP,
|
||||
DestIP: destIP,
|
||||
SrcPort: srcPort,
|
||||
DestPort: destPort,
|
||||
}
|
||||
return info, wrappedConn, nil
|
||||
}
|
||||
|
||||
// BuildV1Header constructs a PROXY protocol v1 header string from two TCP
|
||||
// addresses, normalising the protocol family so that v1's constraint of a
|
||||
// single family per header is satisfied.
|
||||
func BuildV1Header(clientAddr, targetAddr net.Addr) string {
|
||||
clientTCP, ok := clientAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
var protocol, targetIP string
|
||||
|
||||
if clientTCP.IP.To4() != nil {
|
||||
// IPv4 client
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1" // safe fallback for mixed-family
|
||||
}
|
||||
} else {
|
||||
// IPv6 client
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol, clientTCP.IP.String(), targetIP, clientTCP.Port, targetTCP.Port)
|
||||
}
|
||||
|
||||
// BuildV1HeaderFromInfo constructs a PROXY protocol v1 header string using a
|
||||
// previously-parsed *Info (i.e. when this server itself sits behind an
|
||||
// upstream proxy) and the target TCP address.
|
||||
func BuildV1HeaderFromInfo(info *Info, targetAddr net.Addr) string {
|
||||
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP(info.SrcIP)
|
||||
if srcIP == nil {
|
||||
return "PROXY UNKNOWN\r\n"
|
||||
}
|
||||
|
||||
var protocol, targetIP string
|
||||
|
||||
if srcIP.To4() != nil {
|
||||
protocol = "TCP4"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = targetTCP.IP.String()
|
||||
} else if targetTCP.IP.IsLoopback() {
|
||||
targetIP = "127.0.0.1"
|
||||
} else {
|
||||
targetIP = "127.0.0.1"
|
||||
}
|
||||
} else {
|
||||
protocol = "TCP6"
|
||||
if targetTCP.IP.To4() != nil {
|
||||
targetIP = "::ffff:" + targetTCP.IP.String()
|
||||
} else {
|
||||
targetIP = targetTCP.IP.String()
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
||||
protocol, info.SrcIP, targetIP, info.SrcPort, targetTCP.Port)
|
||||
}
|
||||
100
relay/relay.go
100
relay/relay.go
@@ -10,10 +10,12 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/fosrl/gerbil/proxyproto"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -168,19 +170,50 @@ type UDPProxyServer struct {
|
||||
addrCache sync.Map
|
||||
// ReachableAt is the URL where this server can be reached
|
||||
ReachableAt string
|
||||
|
||||
// proxyProtocol enables PROXY protocol v2 header parsing for incoming UDP packets.
|
||||
// When enabled, packets from trustedUpstreams that carry a v2 header will have
|
||||
// their source address overridden with the address reported in the header.
|
||||
proxyProtocol bool
|
||||
trustedUpstreams map[string]struct{}
|
||||
}
|
||||
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
//
|
||||
// proxyProtocol enables PROXY protocol v2 parsing for datagrams arriving from
|
||||
// any address listed in trustedUpstreams (plain IPs or resolvable hostnames).
|
||||
// When a trusted datagram carries a v2 header its source address is replaced
|
||||
// with the address carried inside the header before further processing, so that
|
||||
// hole-punch endpoints reflect the original client IP rather than the load
|
||||
// balancer's address.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string, proxyProtocol bool, trustedUpstreams []string) *UDPProxyServer {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
trustedMap := make(map[string]struct{})
|
||||
for _, upstream := range trustedUpstreams {
|
||||
upstream = strings.TrimSpace(upstream)
|
||||
if upstream == "" {
|
||||
continue
|
||||
}
|
||||
trustedMap[upstream] = struct{}{}
|
||||
// Also resolve any hostnames to their current IPs so we can match by IP.
|
||||
if ips, err := net.LookupIP(upstream); err == nil {
|
||||
for _, ip := range ips {
|
||||
trustedMap[ip.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
proxyProtocol: proxyProtocol,
|
||||
trustedUpstreams: trustedMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,13 +321,48 @@ func (s *UDPProxyServer) readPackets() {
|
||||
// packetWorker processes incoming packets from the channel.
|
||||
func (s *UDPProxyServer) packetWorker() {
|
||||
for packet := range s.packetChan {
|
||||
// Determine packet type by inspecting the first byte.
|
||||
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
||||
// effectiveData and effectiveAddr represent the application-layer payload
|
||||
// and the true originating address. They start as the raw UDP values and
|
||||
// may be updated below when a PROXY protocol v2 header is present.
|
||||
effectiveData := packet.data[:packet.n]
|
||||
effectiveAddr := packet.remoteAddr
|
||||
|
||||
// ---------- PROXY protocol v2 (UDP) ------------------------------------
|
||||
// If proxy protocol is enabled and this datagram arrives from a trusted
|
||||
// upstream (e.g. a load balancer), attempt to parse the v2 header so
|
||||
// that we use the original client address for hole-punch registration and
|
||||
// WireGuard session tracking rather than the load balancer's address.
|
||||
if s.proxyProtocol && len(s.trustedUpstreams) > 0 {
|
||||
remoteHost := packet.remoteAddr.IP.String()
|
||||
if _, trusted := s.trustedUpstreams[remoteHost]; trusted {
|
||||
if info, payload, ok := proxyproto.ParseV2UDPHeader(effectiveData); ok {
|
||||
if info != nil {
|
||||
// Override source address with what the proxy reported.
|
||||
if srcIP := net.ParseIP(info.SrcIP); srcIP != nil {
|
||||
effectiveAddr = &net.UDPAddr{
|
||||
IP: srcIP,
|
||||
Port: info.SrcPort,
|
||||
}
|
||||
logger.Debug("PROXY protocol v2: overriding source %s → %s:%d",
|
||||
packet.remoteAddr, info.SrcIP, info.SrcPort)
|
||||
}
|
||||
}
|
||||
// Always advance past the header so the remainder is treated
|
||||
// as the real application payload.
|
||||
effectiveData = payload
|
||||
}
|
||||
}
|
||||
}
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Determine packet type by inspecting the first byte of the (possibly
|
||||
// stripped) application payload.
|
||||
if len(effectiveData) > 0 && effectiveData[0] >= 1 && effectiveData[0] <= 4 {
|
||||
// Process as a WireGuard packet.
|
||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||
s.handleWireGuardPacket(effectiveData, effectiveAddr)
|
||||
} else {
|
||||
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
|
||||
rateLimitKey := packet.remoteAddr.String()
|
||||
rateLimitKey := effectiveAddr.String()
|
||||
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
|
||||
windowStart: time.Now(),
|
||||
})
|
||||
@@ -316,7 +384,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
|
||||
// Process as an encrypted hole punch message
|
||||
var encMsg EncryptedHolePunchMessage
|
||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||
if err := json.Unmarshal(effectiveData, &encMsg); err != nil {
|
||||
logger.Error("Error unmarshaling encrypted message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
@@ -352,14 +420,14 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
NewtID: msg.NewtID,
|
||||
OlmID: msg.OlmID,
|
||||
Token: msg.Token,
|
||||
IP: packet.remoteAddr.IP.String(),
|
||||
Port: packet.remoteAddr.Port,
|
||||
IP: effectiveAddr.IP.String(),
|
||||
Port: effectiveAddr.Port,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ReachableAt: s.ReachableAt,
|
||||
ExitNodePublicKey: s.privateKey.PublicKey().String(),
|
||||
ClientPublicKey: msg.PublicKey,
|
||||
}
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", effectiveAddr.String(), endpoint.IP, endpoint.Port)
|
||||
s.notifyServer(endpoint)
|
||||
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user