mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-18 22:29:54 +00:00
Add pp to relay
This commit is contained in:
100
relay/relay.go
100
relay/relay.go
@@ -10,10 +10,12 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/fosrl/gerbil/proxyproto"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -168,19 +170,50 @@ type UDPProxyServer struct {
|
||||
addrCache sync.Map
|
||||
// ReachableAt is the URL where this server can be reached
|
||||
ReachableAt string
|
||||
|
||||
// proxyProtocol enables PROXY protocol v2 header parsing for incoming UDP packets.
|
||||
// When enabled, packets from trustedUpstreams that carry a v2 header will have
|
||||
// their source address overridden with the address reported in the header.
|
||||
proxyProtocol bool
|
||||
trustedUpstreams map[string]struct{}
|
||||
}
|
||||
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
//
|
||||
// proxyProtocol enables PROXY protocol v2 parsing for datagrams arriving from
|
||||
// any address listed in trustedUpstreams (plain IPs or resolvable hostnames).
|
||||
// When a trusted datagram carries a v2 header its source address is replaced
|
||||
// with the address carried inside the header before further processing, so that
|
||||
// hole-punch endpoints reflect the original client IP rather than the load
|
||||
// balancer's address.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string, proxyProtocol bool, trustedUpstreams []string) *UDPProxyServer {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
trustedMap := make(map[string]struct{})
|
||||
for _, upstream := range trustedUpstreams {
|
||||
upstream = strings.TrimSpace(upstream)
|
||||
if upstream == "" {
|
||||
continue
|
||||
}
|
||||
trustedMap[upstream] = struct{}{}
|
||||
// Also resolve any hostnames to their current IPs so we can match by IP.
|
||||
if ips, err := net.LookupIP(upstream); err == nil {
|
||||
for _, ip := range ips {
|
||||
trustedMap[ip.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 50000), // Increased from 1000 to handle high throughput
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
proxyProtocol: proxyProtocol,
|
||||
trustedUpstreams: trustedMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,13 +321,48 @@ func (s *UDPProxyServer) readPackets() {
|
||||
// packetWorker processes incoming packets from the channel.
|
||||
func (s *UDPProxyServer) packetWorker() {
|
||||
for packet := range s.packetChan {
|
||||
// Determine packet type by inspecting the first byte.
|
||||
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
||||
// effectiveData and effectiveAddr represent the application-layer payload
|
||||
// and the true originating address. They start as the raw UDP values and
|
||||
// may be updated below when a PROXY protocol v2 header is present.
|
||||
effectiveData := packet.data[:packet.n]
|
||||
effectiveAddr := packet.remoteAddr
|
||||
|
||||
// ---------- PROXY protocol v2 (UDP) ------------------------------------
|
||||
// If proxy protocol is enabled and this datagram arrives from a trusted
|
||||
// upstream (e.g. a load balancer), attempt to parse the v2 header so
|
||||
// that we use the original client address for hole-punch registration and
|
||||
// WireGuard session tracking rather than the load balancer's address.
|
||||
if s.proxyProtocol && len(s.trustedUpstreams) > 0 {
|
||||
remoteHost := packet.remoteAddr.IP.String()
|
||||
if _, trusted := s.trustedUpstreams[remoteHost]; trusted {
|
||||
if info, payload, ok := proxyproto.ParseV2UDPHeader(effectiveData); ok {
|
||||
if info != nil {
|
||||
// Override source address with what the proxy reported.
|
||||
if srcIP := net.ParseIP(info.SrcIP); srcIP != nil {
|
||||
effectiveAddr = &net.UDPAddr{
|
||||
IP: srcIP,
|
||||
Port: info.SrcPort,
|
||||
}
|
||||
logger.Debug("PROXY protocol v2: overriding source %s → %s:%d",
|
||||
packet.remoteAddr, info.SrcIP, info.SrcPort)
|
||||
}
|
||||
}
|
||||
// Always advance past the header so the remainder is treated
|
||||
// as the real application payload.
|
||||
effectiveData = payload
|
||||
}
|
||||
}
|
||||
}
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Determine packet type by inspecting the first byte of the (possibly
|
||||
// stripped) application payload.
|
||||
if len(effectiveData) > 0 && effectiveData[0] >= 1 && effectiveData[0] <= 4 {
|
||||
// Process as a WireGuard packet.
|
||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||
s.handleWireGuardPacket(effectiveData, effectiveAddr)
|
||||
} else {
|
||||
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
|
||||
rateLimitKey := packet.remoteAddr.String()
|
||||
rateLimitKey := effectiveAddr.String()
|
||||
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
|
||||
windowStart: time.Now(),
|
||||
})
|
||||
@@ -316,7 +384,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
|
||||
// Process as an encrypted hole punch message
|
||||
var encMsg EncryptedHolePunchMessage
|
||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||
if err := json.Unmarshal(effectiveData, &encMsg); err != nil {
|
||||
logger.Error("Error unmarshaling encrypted message: %v", err)
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
@@ -352,14 +420,14 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
NewtID: msg.NewtID,
|
||||
OlmID: msg.OlmID,
|
||||
Token: msg.Token,
|
||||
IP: packet.remoteAddr.IP.String(),
|
||||
Port: packet.remoteAddr.Port,
|
||||
IP: effectiveAddr.IP.String(),
|
||||
Port: effectiveAddr.Port,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ReachableAt: s.ReachableAt,
|
||||
ExitNodePublicKey: s.privateKey.PublicKey().String(),
|
||||
ClientPublicKey: msg.PublicKey,
|
||||
}
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", effectiveAddr.String(), endpoint.IP, endpoint.Port)
|
||||
s.notifyServer(endpoint)
|
||||
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user