Integrate sni proxy

This commit is contained in:
Owen
2025-08-13 15:41:58 -07:00
parent f71f183886
commit 1df5eb19ff
2 changed files with 94 additions and 38 deletions

84
main.go
View File

@@ -18,6 +18,7 @@ import (
"time"
"github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
@@ -32,7 +33,8 @@ var (
mu sync.Mutex
wgMu sync.Mutex // Protects WireGuard operations
notifyURL string
proxyServer *relay.UDPProxyServer
proxyRelay *relay.UDPProxyServer
proxySNI *proxy.SNIProxy
)
type WgConfig struct {
@@ -115,6 +117,10 @@ func main() {
reachableAt string
logLevel string
mtu string
sniProxyPort int
localProxyAddr string
localProxyPort int
localOverridesStr string
)
interfaceName = os.Getenv("INTERFACE")
@@ -127,6 +133,11 @@ func main() {
mtu = os.Getenv("MTU")
notifyURL = os.Getenv("NOTIFY_URL")
sniProxyPortStr := os.Getenv("SNI_PORT")
localProxyAddr = os.Getenv("LOCAL_PROXY")
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
@@ -159,6 +170,32 @@ func main() {
if notifyURL == "" {
flag.StringVar(&notifyURL, "notify", "", "URL to notify on peer changes")
}
if sniProxyPortStr != "" {
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
sniProxyPort = port
}
}
if sniProxyPortStr == "" {
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
}
if localProxyAddr == "" {
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
}
if localProxyPortStr != "" {
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
localProxyPort = port
}
}
if localProxyPortStr == "" {
flag.IntVar(&localProxyPort, "local-proxy-port", 9443, "Local proxy port")
}
if localOverridesStr != "" {
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
}
flag.Parse()
logger.Init()
@@ -258,14 +295,26 @@ func main() {
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Start the UDP proxy server
proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
err = proxyServer.Start()
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
}
defer proxyServer.Stop()
defer proxyRelay.Stop()
proxySNI, err := NewSNIProxy(*port, config.Sidecar.ExitNodeName, *localProxyAddr, *localProxyPort, localOverrides)
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
// SO YOU DON'T NEED TO SET THIS SEPARATELY
// Parse local overrides
var localOverrides []string
if localOverridesStr != "" {
localOverrides = strings.Split(localOverridesStr, ",")
for i, domain := range localOverrides {
localOverrides[i] = strings.TrimSpace(domain)
}
logger.Info("Local overrides configured: %v", localOverrides)
}
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, "david", localProxyAddr, localProxyPort, localOverrides)
if err != nil {
logger.Fatal("Failed to create proxy: %v", err)
}
@@ -278,6 +327,7 @@ func main() {
http.HandleFunc("/peer", handlePeer)
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
http.HandleFunc("/update-destinations", handleUpdateDestinations)
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)
// Run HTTP server in a goroutine
@@ -656,9 +706,9 @@ func addPeerInternal(peer Peer) error {
}
// Clear relay connections for the peer's WireGuard IPs
if proxyServer != nil {
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyServer.OnPeerAdded(wgIP)
proxyRelay.OnPeerAdded(wgIP)
}
}
@@ -701,7 +751,7 @@ func removePeerInternal(publicKey string) error {
// Get current peer info before removing to clear relay connections
var wgIPs []string
if proxyServer != nil {
if proxyRelay != nil {
device, err := wgClient.Device(interfaceName)
if err == nil {
for _, peer := range device.Peers {
@@ -730,9 +780,9 @@ func removePeerInternal(publicKey string) error {
}
// Clear relay connections for the peer's WireGuard IPs
if proxyServer != nil {
if proxyRelay != nil {
for _, wgIP := range wgIPs {
proxyServer.OnPeerRemoved(wgIP)
proxyRelay.OnPeerRemoved(wgIP)
}
}
@@ -769,13 +819,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
}
// Update the proxy mappings in the relay server
if proxyServer == nil {
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return
}
updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
updatedCount,
@@ -839,13 +889,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
}
// Update the proxy mappings in the relay server
if proxyServer == nil {
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
return
}
proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
request.SourceIP, request.SourcePort, len(request.Destinations))
@@ -878,6 +928,12 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
return
}
proxySNI.UpdateLocalSNIs(req.FullDomains)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Local SNIs updated successfully",
})
}
func periodicBandwidthCheck(endpoint string) {

View File

@@ -1,4 +1,4 @@
package main
package proxy
import (
"bytes"
@@ -40,7 +40,7 @@ type SNIProxy struct {
exitNodeName string
localProxyAddr string
localProxyPort int
apiBaseURL string
remoteConfigURL string
// New fields for fast local SNI lookup
localSNIs map[string]struct{}
@@ -73,7 +73,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort int, apiBaseURL string, localOverrides []string) (*SNIProxy, error) {
func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map
@@ -92,7 +92,7 @@ func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort i
exitNodeName: exitNodeName,
localProxyAddr: localProxyAddr,
localProxyPort: localProxyPort,
apiBaseURL: apiBaseURL,
remoteConfigURL: remoteConfigURL,
localSNIs: make(map[string]struct{}),
localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel),
@@ -337,7 +337,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
defer cancel()
// Construct API URL
apiURL := fmt.Sprintf("%s/api/route/%s", p.apiBaseURL, hostname)
apiURL := fmt.Sprintf("%s/api/route/%s", p.remoteConfigURL, hostname)
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)