Integrate sni proxy

This commit is contained in:
Owen
2025-08-13 15:41:58 -07:00
parent f71f183886
commit 1df5eb19ff
2 changed files with 94 additions and 38 deletions

84
main.go
View File

@@ -18,6 +18,7 @@ import (
"time" "time"
"github.com/fosrl/gerbil/logger" "github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay" "github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
@@ -32,7 +33,8 @@ var (
mu sync.Mutex mu sync.Mutex
wgMu sync.Mutex // Protects WireGuard operations wgMu sync.Mutex // Protects WireGuard operations
notifyURL string notifyURL string
proxyServer *relay.UDPProxyServer proxyRelay *relay.UDPProxyServer
proxySNI *proxy.SNIProxy
) )
type WgConfig struct { type WgConfig struct {
@@ -115,6 +117,10 @@ func main() {
reachableAt string reachableAt string
logLevel string logLevel string
mtu string mtu string
sniProxyPort int
localProxyAddr string
localProxyPort int
localOverridesStr string
) )
interfaceName = os.Getenv("INTERFACE") interfaceName = os.Getenv("INTERFACE")
@@ -127,6 +133,11 @@ func main() {
mtu = os.Getenv("MTU") mtu = os.Getenv("MTU")
notifyURL = os.Getenv("NOTIFY_URL") notifyURL = os.Getenv("NOTIFY_URL")
sniProxyPortStr := os.Getenv("SNI_PORT")
localProxyAddr = os.Getenv("LOCAL_PROXY")
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
if interfaceName == "" { if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
} }
@@ -159,6 +170,32 @@ func main() {
if notifyURL == "" { if notifyURL == "" {
flag.StringVar(&notifyURL, "notify", "", "URL to notify on peer changes") flag.StringVar(&notifyURL, "notify", "", "URL to notify on peer changes")
} }
if sniProxyPortStr != "" {
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
sniProxyPort = port
}
}
if sniProxyPortStr == "" {
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
}
if localProxyAddr == "" {
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
}
if localProxyPortStr != "" {
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
localProxyPort = port
}
}
if localProxyPortStr == "" {
flag.IntVar(&localProxyPort, "local-proxy-port", 9443, "Local proxy port")
}
if localOverridesStr != "" {
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
}
flag.Parse() flag.Parse()
logger.Init() logger.Init()
@@ -258,14 +295,26 @@ func main() {
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Start the UDP proxy server // Start the UDP proxy server
proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
err = proxyServer.Start() err = proxyRelay.Start()
if err != nil { if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err) logger.Fatal("Failed to start UDP proxy server: %v", err)
} }
defer proxyServer.Stop() defer proxyRelay.Stop()
proxySNI, err := NewSNIProxy(*port, config.Sidecar.ExitNodeName, *localProxyAddr, *localProxyPort, localOverrides) // TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
// SO YOU DON'T NEED TO SET THIS SEPARATELY
// Parse local overrides
var localOverrides []string
if localOverridesStr != "" {
localOverrides = strings.Split(localOverridesStr, ",")
for i, domain := range localOverrides {
localOverrides[i] = strings.TrimSpace(domain)
}
logger.Info("Local overrides configured: %v", localOverrides)
}
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, "david", localProxyAddr, localProxyPort, localOverrides)
if err != nil { if err != nil {
logger.Fatal("Failed to create proxy: %v", err) logger.Fatal("Failed to create proxy: %v", err)
} }
@@ -278,6 +327,7 @@ func main() {
http.HandleFunc("/peer", handlePeer) http.HandleFunc("/peer", handlePeer)
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
http.HandleFunc("/update-destinations", handleUpdateDestinations) http.HandleFunc("/update-destinations", handleUpdateDestinations)
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr) logger.Info("Starting HTTP server on %s", listenAddr)
// Run HTTP server in a goroutine // Run HTTP server in a goroutine
@@ -656,9 +706,9 @@ func addPeerInternal(peer Peer) error {
} }
// Clear relay connections for the peer's WireGuard IPs // Clear relay connections for the peer's WireGuard IPs
if proxyServer != nil { if proxyRelay != nil {
for _, wgIP := range wgIPs { for _, wgIP := range wgIPs {
proxyServer.OnPeerAdded(wgIP) proxyRelay.OnPeerAdded(wgIP)
} }
} }
@@ -701,7 +751,7 @@ func removePeerInternal(publicKey string) error {
// Get current peer info before removing to clear relay connections // Get current peer info before removing to clear relay connections
var wgIPs []string var wgIPs []string
if proxyServer != nil { if proxyRelay != nil {
device, err := wgClient.Device(interfaceName) device, err := wgClient.Device(interfaceName)
if err == nil { if err == nil {
for _, peer := range device.Peers { for _, peer := range device.Peers {
@@ -730,9 +780,9 @@ func removePeerInternal(publicKey string) error {
} }
// Clear relay connections for the peer's WireGuard IPs // Clear relay connections for the peer's WireGuard IPs
if proxyServer != nil { if proxyRelay != nil {
for _, wgIP := range wgIPs { for _, wgIP := range wgIPs {
proxyServer.OnPeerRemoved(wgIP) proxyRelay.OnPeerRemoved(wgIP)
} }
} }
@@ -769,13 +819,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
} }
// Update the proxy mappings in the relay server // Update the proxy mappings in the relay server
if proxyServer == nil { if proxyRelay == nil {
logger.Error("Proxy server is not available") logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError) http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return return
} }
updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d", logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
updatedCount, updatedCount,
@@ -839,13 +889,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
} }
// Update the proxy mappings in the relay server // Update the proxy mappings in the relay server
if proxyServer == nil { if proxyRelay == nil {
logger.Error("Proxy server is not available") logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError) http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return return
} }
proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
logger.Info("Updated proxy mapping for %s:%d with %d destinations", logger.Info("Updated proxy mapping for %s:%d with %d destinations",
request.SourceIP, request.SourcePort, len(request.Destinations)) request.SourceIP, request.SourcePort, len(request.Destinations))
@@ -878,6 +928,12 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
return return
} }
proxySNI.UpdateLocalSNIs(req.FullDomains)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Local SNIs updated successfully",
})
} }
func periodicBandwidthCheck(endpoint string) { func periodicBandwidthCheck(endpoint string) {

View File

@@ -1,4 +1,4 @@
package main package proxy
import ( import (
"bytes" "bytes"
@@ -31,16 +31,16 @@ type RouteAPIResponse struct {
// SNIProxy represents the main proxy server // SNIProxy represents the main proxy server
type SNIProxy struct { type SNIProxy struct {
port int port int
cache *cache.Cache cache *cache.Cache
listener net.Listener listener net.Listener
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wg sync.WaitGroup wg sync.WaitGroup
exitNodeName string exitNodeName string
localProxyAddr string localProxyAddr string
localProxyPort int localProxyPort int
apiBaseURL string remoteConfigURL string
// New fields for fast local SNI lookup // New fields for fast local SNI lookup
localSNIs map[string]struct{} localSNIs map[string]struct{}
@@ -73,7 +73,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// NewSNIProxy creates a new SNI proxy instance // NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort int, apiBaseURL string, localOverrides []string) (*SNIProxy, error) { func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map // Create local overrides map
@@ -85,17 +85,17 @@ func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort i
} }
proxy := &SNIProxy{ proxy := &SNIProxy{
port: port, port: port,
cache: cache.New(3*time.Second, 10*time.Minute), cache: cache.New(3*time.Second, 10*time.Minute),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
exitNodeName: exitNodeName, exitNodeName: exitNodeName,
localProxyAddr: localProxyAddr, localProxyAddr: localProxyAddr,
localProxyPort: localProxyPort, localProxyPort: localProxyPort,
apiBaseURL: apiBaseURL, remoteConfigURL: remoteConfigURL,
localSNIs: make(map[string]struct{}), localSNIs: make(map[string]struct{}),
localOverrides: overridesMap, localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel), activeTunnels: make(map[string]*activeTunnel),
} }
return proxy, nil return proxy, nil
@@ -337,7 +337,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
defer cancel() defer cancel()
// Construct API URL // Construct API URL
apiURL := fmt.Sprintf("%s/api/route/%s", p.apiBaseURL, hostname) apiURL := fmt.Sprintf("%s/api/route/%s", p.remoteConfigURL, hostname)
// Create HTTP request // Create HTTP request
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)