20 Commits
1.1.0 ... 1.2.1

Author SHA1 Message Date
Owen Schwartz
a8a0f92c9b Merge pull request #23 from fosrl/dev
Add proxy protocol
2025-08-27 14:22:08 -07:00
Owen
7040a9436e Add proxy protocol 2025-08-26 22:26:01 -07:00
Owen
04361242fe Update readme 2025-08-23 12:29:26 -07:00
Owen
554b1d55dc Merge branch 'main' into dev 2025-08-23 12:24:21 -07:00
Owen Schwartz
47589570c9 Merge pull request #20 from Lokowitz/sync-go-versions
update versions and sync go version in all files
2025-08-22 21:37:20 -07:00
Owen
9f5b8dea26 Merge branch 'hybrid' into dev 2025-08-22 11:56:58 -07:00
Owen
f6a1e1e27c Merge branch 'main' into dev 2025-08-22 11:56:54 -07:00
Owen
f983a8f141 Local proxy port 443 2025-08-22 11:56:29 -07:00
Owen
efce3cb0b2 Sni has no errors now 2025-08-17 10:43:37 -07:00
Marvin
6eeebd81b2 sync go versions 2025-08-17 11:48:39 +00:00
Owen
c970fd5a18 Update to work with multipe endpoints 2025-08-16 22:59:45 -07:00
Owen
09bd02456d Move to post 2025-08-16 22:53:49 -07:00
Owen
c24537af36 Fix url 2025-08-16 22:36:03 -07:00
Owen
9de3f14799 Update default config 2025-08-16 22:35:51 -07:00
Owen Schwartz
0908f75f5f Merge pull request #19 from fosrl/dependabot/docker/minor-updates-80a311fbba
Bump golang from 1.24.3-alpine to 1.25.0-alpine in the minor-updates group
2025-08-15 09:40:54 -07:00
Owen
10958f8c55 Use propper logger 2025-08-14 22:25:38 -07:00
dependabot[bot]
b1840fd5c3 Bump golang in the minor-updates group
Bumps the minor-updates group with 1 update: golang.


Updates `golang` from 1.24.3-alpine to 1.25.0-alpine

---
updated-dependencies:
- dependency-name: golang
  dependency-version: 1.25.0-alpine
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-14 21:55:42 +00:00
Owen
1df5eb19ff Integrate sni proxy 2025-08-13 15:41:58 -07:00
Owen
f71f183886 Add basic proxy 2025-08-12 18:02:34 -07:00
Owen
8922ca9736 Fix some clients stuff for multi pop 2025-08-12 17:26:14 -07:00
12 changed files with 834 additions and 38 deletions

View File

@@ -33,3 +33,8 @@ updates:
minor-updates:
update-types:
- "minor"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

View File

@@ -12,16 +12,16 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
@@ -31,9 +31,9 @@ jobs:
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Install Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: 1.23.1
go-version: 1.25
- name: Build and push Docker images
run: |

View File

@@ -14,9 +14,9 @@ jobs:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: '1.23'
go-version: '1.25'
- name: Build go
run: go build

1
.go-version Normal file
View File

@@ -0,0 +1 @@
1.25

View File

@@ -1,4 +1,4 @@
FROM golang:1.24.3-alpine AS builder
FROM golang:1.25-alpine AS builder
# Set the working directory inside the container
WORKDIR /app

View File

@@ -24,7 +24,21 @@ Bytes transmitted in and out of each peer are collected every 10 seconds, and in
### Handle client relaying
Gerbil listens on port 21820 for incoming UDP hole punch packets to orchestrate NAT hole punching between olm and newt clients. Additionally, it handles relaying data through the gerbil server down to the newt. This is accomplished by scanning each packet for headers and handling them appropriately.
Gerbil listens on port 21820 for incoming UDP hole punch packets to orchestrate NAT hole punching between olm and newt clients. Additionally, it handles relaying data through the gerbil server down to the newt. This is accomplished by scanning each packet for headers and handling them appropriately.
### SNI Proxy
Gerbil includes an SNI (Server Name Indication) proxy that enables intelligent routing of HTTPS traffic between Pangolin nodes. When a TLS connection comes in, the proxy extracts the hostname from the SNI extension and queries Pangolin to determine the correct routing destination. This allows seamless routing of web traffic through the WireGuard mesh network:
- If the hostname is configured for local handling (via local overrides or local SNIs), traffic is routed to the local proxy
- Otherwise, the proxy queries Pangolin's routing API to determine which node should handle the traffic
- Supports caching of routing decisions to improve performance
- Handles connection pooling and graceful shutdown
- Optional PROXY protocol v1 support to preserve original client IP addresses when forwarding to downstream proxies (HAProxy, Nginx, etc.)
The PROXY protocol allows downstream proxies to know the real client IP address instead of seeing the SNI proxy's IP. When enabled with `--proxy-protocol`, the SNI proxy will prepend a PROXY protocol header to each connection containing the original client's IP and port information.
In single node (self hosted) Pangolin deployments this can be bypassed by using port 443:443 to route to Traefik instead of the SNI proxy at 8443.
## CLI Args
@@ -41,6 +55,11 @@ Note: You must use either `config` or `remoteConfig` to configure WireGuard.
- `log-level` (optional): The log level to use (DEBUG, INFO, WARN, ERROR, FATAL). Default: `INFO`
- `mtu` (optional): MTU of the WireGuard interface. Default: `1280`
- `notify` (optional): URL to notify on peer changes
- `sni-port` (optional): Port for the SNI proxy to listen on. Default: `8443`
- `local-proxy` (optional): Address for local proxy when routing local traffic. Default: `localhost`
- `local-proxy-port` (optional): Port for local proxy when routing local traffic. Default: `443`
- `local-overrides` (optional): Comma-separated list of domain names that should always be routed to the local proxy
- `proxy-protocol` (optional): Enable PROXY protocol v1 for preserving client IP addresses when forwarding to downstream proxies. Default: `false`
## Environment Variables
@@ -55,6 +74,11 @@ All CLI arguments can also be provided via environment variables:
- `LOG_LEVEL`: Log level (DEBUG, INFO, WARN, ERROR, FATAL)
- `MTU`: MTU of the WireGuard interface
- `NOTIFY_URL`: URL to notify on peer changes
- `SNI_PORT`: Port for the SNI proxy to listen on
- `LOCAL_PROXY`: Address for local proxy when routing local traffic
- `LOCAL_PROXY_PORT`: Port for local proxy when routing local traffic
- `LOCAL_OVERRIDES`: Comma-separated list of domain names that should always be routed to the local proxy
- `PROXY_PROTOCOL`: Enable PROXY protocol v1 for preserving client IP addresses (true/false)
Example:
@@ -62,8 +86,7 @@ Example:
./gerbil \
--reachableAt=http://gerbil:3003 \
--generateAndSaveKeyTo=/var/config/key \
--remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config \
--reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
--remoteConfig=http://pangolin:3001/api/v1/
```
```yaml
@@ -75,8 +98,7 @@ services:
command:
- --reachableAt=http://gerbil:3003
- --generateAndSaveKeyTo=/var/config/key
- --remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config
- --reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth
- --remoteConfig=http://pangolin:3001/api/v1/
volumes:
- ./config/:/var/config
cap_add:
@@ -85,6 +107,7 @@ services:
ports:
- 51820:51820/udp
- 21820:21820/udp
- 443:8443/tcp # SNI proxy port
```
## Build

7
go.mod
View File

@@ -1,11 +1,10 @@
module github.com/fosrl/gerbil
go 1.23.1
toolchain go1.23.2
go 1.25
require (
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.36.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
)
@@ -15,8 +14,8 @@ require (
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.31.0 // indirect

2
go.sum
View File

@@ -10,6 +10,8 @@ github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=

120
main.go
View File

@@ -18,6 +18,7 @@ import (
"time"
"github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
@@ -32,7 +33,8 @@ var (
mu sync.Mutex
wgMu sync.Mutex // Protects WireGuard operations
notifyURL string
proxyServer *relay.UDPProxyServer
proxyRelay *relay.UDPProxyServer
proxySNI *proxy.SNIProxy
)
type WgConfig struct {
@@ -115,6 +117,11 @@ func main() {
reachableAt string
logLevel string
mtu string
sniProxyPort int
localProxyAddr string
localProxyPort int
localOverridesStr string
proxyProtocol bool
)
interfaceName = os.Getenv("INTERFACE")
@@ -127,6 +134,12 @@ func main() {
mtu = os.Getenv("MTU")
notifyURL = os.Getenv("NOTIFY_URL")
sniProxyPortStr := os.Getenv("SNI_PORT")
localProxyAddr = os.Getenv("LOCAL_PROXY")
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
@@ -159,6 +172,39 @@ func main() {
if notifyURL == "" {
flag.StringVar(&notifyURL, "notify", "", "URL to notify on peer changes")
}
if sniProxyPortStr != "" {
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
sniProxyPort = port
}
}
if sniProxyPortStr == "" {
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
}
if localProxyAddr == "" {
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
}
if localProxyPortStr != "" {
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
localProxyPort = port
}
}
if localProxyPortStr == "" {
flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port")
}
if localOverridesStr != "" {
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
}
if proxyProtocolStr != "" {
proxyProtocol = strings.ToLower(proxyProtocolStr) == "true"
}
if proxyProtocolStr == "" {
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
}
flag.Parse()
logger.Init()
@@ -258,17 +304,39 @@ func main() {
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Start the UDP proxy server
proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
err = proxyServer.Start()
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
}
defer proxyServer.Stop()
defer proxyRelay.Stop()
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
// SO YOU DON'T NEED TO SET THIS SEPARATELY
// Parse local overrides
var localOverrides []string
if localOverridesStr != "" {
localOverrides = strings.Split(localOverridesStr, ",")
for i, domain := range localOverrides {
localOverrides[i] = strings.TrimSpace(domain)
}
logger.Info("Local overrides configured: %v", localOverrides)
}
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides, proxyProtocol)
if err != nil {
logger.Fatal("Failed to create proxy: %v", err)
}
if err := proxySNI.Start(); err != nil {
logger.Fatal("Failed to start proxy: %v", err)
}
// Set up HTTP server
http.HandleFunc("/peer", handlePeer)
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
http.HandleFunc("/update-destinations", handleUpdateDestinations)
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)
// Run HTTP server in a goroutine
@@ -647,9 +715,9 @@ func addPeerInternal(peer Peer) error {
}
// Clear relay connections for the peer's WireGuard IPs
if proxyServer != nil {
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyServer.OnPeerAdded(wgIP)
proxyRelay.OnPeerAdded(wgIP)
}
}
@@ -692,7 +760,7 @@ func removePeerInternal(publicKey string) error {
// Get current peer info before removing to clear relay connections
var wgIPs []string
if proxyServer != nil {
if proxyRelay != nil {
device, err := wgClient.Device(interfaceName)
if err == nil {
for _, peer := range device.Peers {
@@ -721,9 +789,9 @@ func removePeerInternal(publicKey string) error {
}
// Clear relay connections for the peer's WireGuard IPs
if proxyServer != nil {
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyServer.OnPeerRemoved(wgIP)
proxyRelay.OnPeerRemoved(wgIP)
}
}
@@ -760,13 +828,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
}
// Update the proxy mappings in the relay server
if proxyServer == nil {
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return
}
updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
updatedCount,
@@ -830,13 +898,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
}
// Update the proxy mappings in the relay server
if proxyServer == nil {
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return
}
proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
request.SourceIP, request.SourcePort, len(request.Destinations))
@@ -851,6 +919,32 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
})
}
// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs
type UpdateLocalSNIsRequest struct {
FullDomains []string `json:"fullDomains"`
}
func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
logger.Error("Invalid method: %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req UpdateLocalSNIsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return
}
proxySNI.UpdateLocalSNIs(req.FullDomains)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Local SNIs updated successfully",
})
}
func periodicBandwidthCheck(endpoint string) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

591
proxy/proxy.go Normal file
View File

@@ -0,0 +1,591 @@
package proxy
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/fosrl/gerbil/logger"
"github.com/patrickmn/go-cache"
)
// RouteRecord represents a routing configuration
type RouteRecord struct {
Hostname string
TargetHost string
TargetPort int
}
// RouteAPIResponse represents the response from the route API
type RouteAPIResponse struct {
Endpoints []string `json:"endpoints"`
}
// SNIProxy represents the main proxy server
type SNIProxy struct {
port int
cache *cache.Cache
listener net.Listener
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
localProxyAddr string
localProxyPort int
remoteConfigURL string
publicKey string
proxyProtocol bool // Enable PROXY protocol v1
// New fields for fast local SNI lookup
localSNIs map[string]struct{}
localSNIsLock sync.RWMutex
// Local overrides for domains that should always use local proxy
localOverrides map[string]struct{}
// Track active tunnels by SNI
activeTunnels map[string]*activeTunnel
activeTunnelsLock sync.Mutex
}
type activeTunnel struct {
conns []net.Conn
}
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
type readOnlyConn struct {
reader io.Reader
}
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// buildProxyProtocolHeader creates a PROXY protocol v1 header
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
clientTCP, ok := clientAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Determine protocol family based on client IP and normalize target IP accordingly
var protocol string
var targetIP string
if clientTCP.IP.To4() != nil {
// Client is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
// or fall back to a sensible IPv4 address based on the target
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Client is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
clientTCP.IP.String(),
targetIP,
clientTCP.Port,
targetTCP.Port)
}
// NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map
overridesMap := make(map[string]struct{})
for _, domain := range localOverrides {
if domain != "" {
overridesMap[domain] = struct{}{}
}
}
proxy := &SNIProxy{
port: port,
cache: cache.New(3*time.Second, 10*time.Minute),
ctx: ctx,
cancel: cancel,
localProxyAddr: localProxyAddr,
localProxyPort: localProxyPort,
remoteConfigURL: remoteConfigURL,
publicKey: publicKey,
proxyProtocol: proxyProtocol,
localSNIs: make(map[string]struct{}),
localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel),
}
return proxy, nil
}
// Start begins listening for connections
func (p *SNIProxy) Start() error {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
logger.Debug("SNI Proxy listening on port %d", p.port)
// Accept connections in a goroutine
go p.acceptConnections()
return nil
}
// Stop gracefully shuts down the proxy
func (p *SNIProxy) Stop() error {
log.Println("Stopping SNI Proxy...")
p.cancel()
if p.listener != nil {
p.listener.Close()
}
// Wait for all goroutines to finish with timeout
done := make(chan struct{})
go func() {
p.wg.Wait()
close(done)
}()
select {
case <-done:
log.Println("All connections closed gracefully")
case <-time.After(30 * time.Second):
log.Println("Timeout waiting for connections to close")
}
log.Println("SNI Proxy stopped")
return nil
}
// acceptConnections handles incoming connections
func (p *SNIProxy) acceptConnections() {
for {
conn, err := p.listener.Accept()
if err != nil {
select {
case <-p.ctx.Done():
return
default:
logger.Debug("Accept error: %v", err)
continue
}
}
p.wg.Add(1)
go p.handleConnection(conn)
}
}
// readClientHello reads and parses the TLS ClientHello message
func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
var hello *tls.ClientHelloInfo
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
if hello == nil {
return nil, err
}
return hello, nil
}
// peekClientHello reads the ClientHello while preserving the data for forwarding
func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
peekedBytes := new(bytes.Buffer)
hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes))
if err != nil {
return nil, nil, err
}
return hello, io.MultiReader(peekedBytes, reader), nil
}
// extractSNI extracts the SNI hostname from the TLS ClientHello
func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) {
clientHello, clientReader, err := p.peekClientHello(conn)
if err != nil {
return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err)
}
if clientHello.ServerName == "" {
return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello")
}
return clientHello.ServerName, clientReader, nil
}
// handleConnection processes a single client connection
func (p *SNIProxy) handleConnection(clientConn net.Conn) {
defer p.wg.Done()
defer clientConn.Close()
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
// Set read timeout for SNI extraction
if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
logger.Debug("Failed to set read deadline: %v", err)
return
}
// Extract SNI hostname
hostname, clientReader, err := p.extractSNI(clientConn)
if err != nil {
logger.Debug("SNI extraction failed: %v", err)
return
}
if hostname == "" {
log.Println("No SNI hostname found")
return
}
logger.Debug("SNI hostname detected: %s", hostname)
// Remove read timeout for normal operation
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
return
}
// Get routing information
route, err := p.getRoute(hostname, clientConn.RemoteAddr().String())
if err != nil {
logger.Debug("Failed to get route for %s: %v", hostname, err)
return
}
if route == nil {
logger.Debug("No route found for hostname: %s", hostname)
return
}
logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort)
// Connect to target server
targetConn, err := net.DialTimeout("tcp",
fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort),
10*time.Second)
if err != nil {
logger.Debug("Failed to connect to target %s:%d: %v",
route.TargetHost, route.TargetPort, err)
return
}
defer targetConn.Close()
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
// Send PROXY protocol header if enabled
if p.proxyProtocol {
proxyHeader := buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
logger.Debug("Failed to send PROXY protocol header: %v", err)
return
}
}
// Track this tunnel by SNI
p.activeTunnelsLock.Lock()
tunnel, ok := p.activeTunnels[hostname]
if !ok {
tunnel = &activeTunnel{}
p.activeTunnels[hostname] = tunnel
}
tunnel.conns = append(tunnel.conns, clientConn)
p.activeTunnelsLock.Unlock()
defer func() {
// Remove this conn from active tunnels
p.activeTunnelsLock.Lock()
if tunnel, ok := p.activeTunnels[hostname]; ok {
newConns := make([]net.Conn, 0, len(tunnel.conns))
for _, c := range tunnel.conns {
if c != clientConn {
newConns = append(newConns, c)
}
}
if len(newConns) == 0 {
delete(p.activeTunnels, hostname)
} else {
tunnel.conns = newConns
}
}
p.activeTunnelsLock.Unlock()
}()
// Start bidirectional data transfer
p.pipe(clientConn, targetConn, clientReader)
}
// getRoute retrieves routing information for a hostname
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check local overrides first
if _, isOverride := p.localOverrides[hostname]; isOverride {
logger.Debug("Local override matched for hostname: %s", hostname)
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
TargetPort: p.localProxyPort,
}, nil
}
// Fast path: check if hostname is in localSNIs
p.localSNIsLock.RLock()
_, isLocal := p.localSNIs[hostname]
p.localSNIsLock.RUnlock()
if isLocal {
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
TargetPort: p.localProxyPort,
}, nil
}
// Check cache first
if cached, found := p.cache.Get(hostname); found {
if cached == nil {
return nil, nil // Cached negative result
}
logger.Debug("Cache hit for hostname: %s", hostname)
return cached.(*RouteRecord), nil
}
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
// Query API with timeout
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
defer cancel()
// Construct API URL (without hostname in path)
apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL)
// Create request body with hostname and public key
requestBody := map[string]string{
"hostname": hostname,
"publicKey": p.publicKey,
}
jsonBody, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
// Make HTTP request
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("API request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Cache negative result for shorter time (1 minute)
p.cache.Set(hostname, nil, 1*time.Minute)
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
}
// Parse response
var apiResponse RouteAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode API response: %w", err)
}
endpoints := apiResponse.Endpoints
// Default target configuration
targetHost := p.localProxyAddr
targetPort := p.localProxyPort
// If no endpoints returned, use local node
if len(endpoints) == 0 {
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
} else {
// Select endpoint using consistent hashing for stickiness
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
targetHost = selectedEndpoint
targetPort = 443 // Default HTTPS port
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
}
route := &RouteRecord{
Hostname: hostname,
TargetHost: targetHost,
TargetPort: targetPort,
}
// Cache the result
p.cache.Set(hostname, route, cache.DefaultExpiration)
logger.Debug("Cached route for hostname: %s", hostname)
return route, nil
}
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
// the same client always routes to the same endpoint for load balancing
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
if len(endpoints) == 0 {
return p.localProxyAddr
}
if len(endpoints) == 1 {
return endpoints[0]
}
// Use FNV hash for consistent selection based on client address
hash := fnv.New32a()
hash.Write([]byte(clientAddr))
index := hash.Sum32() % uint32(len(endpoints))
return endpoints[index]
}
// pipe handles bidirectional data transfer between connections
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
var wg sync.WaitGroup
wg.Add(2)
// Copy data from client to target (using the buffered reader)
go func() {
defer wg.Done()
defer func() {
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// Use a large buffer for better performance
buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(targetConn, clientReader, buf)
if err != nil && err != io.EOF {
logger.Debug("Copy client->target error: %v", err)
}
}()
// Copy data from target to client
go func() {
defer wg.Done()
defer func() {
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// Use a large buffer for better performance
buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(clientConn, targetConn, buf)
if err != nil && err != io.EOF {
logger.Debug("Copy target->client error: %v", err)
}
}()
wg.Wait()
}
// GetCacheStats returns cache statistics
func (p *SNIProxy) GetCacheStats() (int, int) {
return p.cache.ItemCount(), len(p.cache.Items())
}
// ClearCache clears all cached entries
func (p *SNIProxy) ClearCache() {
p.cache.Flush()
log.Println("Cache cleared")
}
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
newSNIs := make(map[string]struct{})
for _, domain := range fullDomains {
newSNIs[domain] = struct{}{}
// Invalidate any cached route for this domain
p.cache.Delete(domain)
}
// Update localSNIs
p.localSNIsLock.Lock()
removed := make([]string, 0)
for sni := range p.localSNIs {
if _, stillLocal := newSNIs[sni]; !stillLocal {
removed = append(removed, sni)
}
}
p.localSNIs = newSNIs
p.localSNIsLock.Unlock()
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
// Terminate tunnels for removed SNIs
if len(removed) > 0 {
p.activeTunnelsLock.Lock()
for _, sni := range removed {
if tunnels, ok := p.activeTunnels[sni]; ok {
for _, conn := range tunnels.conns {
conn.Close()
}
delete(p.activeTunnels, sni)
logger.Debug("Closed tunnels for SNI target change: %s", sni)
}
}
p.activeTunnelsLock.Unlock()
}
}

78
proxy/proxy_test.go Normal file
View File

@@ -0,0 +1,78 @@
package proxy
import (
"net"
"testing"
)
func TestBuildProxyProtocolHeader(t *testing.T) {
tests := []struct {
name string
clientAddr string
targetAddr string
expected string
}{
{
name: "IPv4 client and target",
clientAddr: "192.168.1.100:12345",
targetAddr: "10.0.0.1:443",
expected: "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n",
},
{
name: "IPv6 client and target",
clientAddr: "[2001:db8::1]:12345",
targetAddr: "[2001:db8::2]:443",
expected: "PROXY TCP6 2001:db8::1 2001:db8::2 12345 443\r\n",
},
{
name: "IPv4 client with IPv6 loopback target",
clientAddr: "192.168.1.100:12345",
targetAddr: "[::1]:443",
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
},
{
name: "IPv4 client with IPv6 target",
clientAddr: "192.168.1.100:12345",
targetAddr: "[2001:db8::2]:443",
expected: "PROXY TCP4 192.168.1.100 127.0.0.1 12345 443\r\n",
},
{
name: "IPv6 client with IPv4 target",
clientAddr: "[2001:db8::1]:12345",
targetAddr: "10.0.0.1:443",
expected: "PROXY TCP6 2001:db8::1 ::ffff:10.0.0.1 12345 443\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientTCP, err := net.ResolveTCPAddr("tcp", tt.clientAddr)
if err != nil {
t.Fatalf("Failed to resolve client address: %v", err)
}
targetTCP, err := net.ResolveTCPAddr("tcp", tt.targetAddr)
if err != nil {
t.Fatalf("Failed to resolve target address: %v", err)
}
result := buildProxyProtocolHeader(clientTCP, targetTCP)
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
func TestBuildProxyProtocolHeaderUnknownType(t *testing.T) {
// Test with non-TCP address type
clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345}
targetAddr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
result := buildProxyProtocolHeader(clientAddr, targetAddr)
expected := "PROXY UNKNOWN\r\n"
if result != expected {
t.Errorf("Expected %q, got %q", expected, result)
}
}

View File

@@ -37,6 +37,7 @@ type ClientEndpoint struct {
Port int `json:"port"`
Timestamp int64 `json:"timestamp"`
ReachableAt string `json:"reachableAt"`
PublicKey string `json:"publicKey"`
}
// Updated to support multiple destination peers
@@ -225,9 +226,11 @@ func (s *UDPProxyServer) packetWorker() {
Port: packet.remoteAddr.Port,
Timestamp: time.Now().Unix(),
ReachableAt: s.ReachableAt,
PublicKey: s.privateKey.PublicKey().String(),
}
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
s.notifyServer(endpoint)
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
}
// Return the buffer to the pool for reuse.
bufferPool.Put(packet.data[:1500])
@@ -355,7 +358,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
switch messageType {
case WireGuardMessageTypeHandshakeInitiation:
// Initial handshake: forward to all peers
logger.Debug("Forwarding handshake initiation from %s (sender index: %d)", remoteAddr, senderIndex)
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
@@ -414,7 +417,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
case WireGuardMessageTypeTransportData:
// Data packet: forward only to the established session peer
logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
// logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
// Look up the session based on the receiver index
var destAddr *net.UDPAddr
@@ -660,7 +663,7 @@ func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, des
func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
logger.Info("Clearing connections for added peer with WG IP: %s", wgIP)
s.clearConnectionsForWGIP(wgIP)
s.clearSessionsForWGIP(wgIP)
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
// s.clearProxyMappingsForWGIP(wgIP)
}
@@ -668,7 +671,7 @@ func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
func (s *UDPProxyServer) OnPeerRemoved(wgIP string) {
logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP)
s.clearConnectionsForWGIP(wgIP)
s.clearSessionsForWGIP(wgIP)
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
// s.clearProxyMappingsForWGIP(wgIP)
}
@@ -699,7 +702,7 @@ func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) {
}
// clearSessionsForWGIP removes all WireGuard sessions associated with a specific WireGuard IP
func (s *UDPProxyServer) clearSessionsForWGIP(wgIP string) {
func (s *UDPProxyServer) clearSessionsForIP(ip string) {
var keysToDelete []string
s.wgSessions.Range(func(key, value interface{}) bool {
@@ -707,9 +710,9 @@ func (s *UDPProxyServer) clearSessionsForWGIP(wgIP string) {
session := value.(*WireGuardSession)
// Check if the session's destination address contains the WG IP
if session.DestAddr != nil && session.DestAddr.IP.String() == wgIP {
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
keysToDelete = append(keysToDelete, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", wgIP, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
}
return true
})
@@ -719,7 +722,7 @@ func (s *UDPProxyServer) clearSessionsForWGIP(wgIP string) {
s.wgSessions.Delete(key)
}
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), wgIP)
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
}
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP