diff --git a/main.go b/main.go index b7b8ba3..d6fcf73 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "time" "github.com/fosrl/gerbil/logger" + "github.com/fosrl/gerbil/proxy" "github.com/fosrl/gerbil/relay" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" @@ -32,7 +33,8 @@ var ( mu sync.Mutex wgMu sync.Mutex // Protects WireGuard operations notifyURL string - proxyServer *relay.UDPProxyServer + proxyRelay *relay.UDPProxyServer + proxySNI *proxy.SNIProxy ) type WgConfig struct { @@ -115,6 +117,10 @@ func main() { reachableAt string logLevel string mtu string + sniProxyPort int + localProxyAddr string + localProxyPort int + localOverridesStr string ) interfaceName = os.Getenv("INTERFACE") @@ -127,6 +133,11 @@ func main() { mtu = os.Getenv("MTU") notifyURL = os.Getenv("NOTIFY_URL") + sniProxyPortStr := os.Getenv("SNI_PORT") + localProxyAddr = os.Getenv("LOCAL_PROXY") + localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT") + localOverridesStr = os.Getenv("LOCAL_OVERRIDES") + if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } @@ -159,6 +170,32 @@ func main() { if notifyURL == "" { flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes") } + + if sniProxyPortStr != "" { + if port, err := strconv.Atoi(sniProxyPortStr); err == nil { + sniProxyPort = port + } + } + if sniProxyPortStr == "" { + flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on") + } + + if localProxyAddr == "" { + flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address") + } + + if localProxyPortStr != "" { + if port, err := strconv.Atoi(localProxyPortStr); err == nil { + localProxyPort = port + } + } + if localProxyPortStr == "" { + flag.IntVar(&localProxyPort, "local-proxy-port", 9443, "Local proxy port") + } + if localOverridesStr != "" { + flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy") + } + flag.Parse() logger.Init() @@ -258,14 +295,26 @@ func main() { go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") // Start the UDP proxy server - proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) - err = proxyServer.Start() + proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) + err = proxyRelay.Start() if err != nil { logger.Fatal("Failed to start UDP proxy server: %v", err) } - defer proxyServer.Stop() + defer proxyRelay.Stop() - proxySNI, err := NewSNIProxy(*port, config.Sidecar.ExitNodeName, *localProxyAddr, *localProxyPort, localOverrides) + // TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING + // SO YOU DON'T NEED TO SET THIS SEPARATELY + // Parse local overrides + var localOverrides []string + if localOverridesStr != "" { + localOverrides = strings.Split(localOverridesStr, ",") + for i, domain := range localOverrides { + localOverrides[i] = strings.TrimSpace(domain) + } + logger.Info("Local overrides configured: %v", localOverrides) + } + + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, "david", localProxyAddr, localProxyPort, localOverrides) if err != nil { logger.Fatal("Failed to create proxy: %v", err) } @@ -278,6 +327,7 @@ func main() { http.HandleFunc("/peer", handlePeer) http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) http.HandleFunc("/update-destinations", handleUpdateDestinations) + http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) logger.Info("Starting HTTP server on %s", listenAddr) // Run HTTP server in a goroutine @@ -656,9 +706,9 @@ func addPeerInternal(peer Peer) error { } // Clear relay connections for the peer's WireGuard IPs - if proxyServer != nil { + if proxyRelay != nil { for _, wgIP := range wgIPs { - proxyServer.OnPeerAdded(wgIP) + proxyRelay.OnPeerAdded(wgIP) } } @@ -701,7 +751,7 @@ func removePeerInternal(publicKey string) error { // Get current peer info before removing to clear relay connections var wgIPs []string - if proxyServer != nil { + if proxyRelay != nil { device, err := wgClient.Device(interfaceName) if err == nil { for _, peer := range device.Peers { @@ -730,9 +780,9 @@ func removePeerInternal(publicKey string) error { } // Clear relay connections for the peer's WireGuard IPs - if proxyServer != nil { + if proxyRelay != nil { for _, wgIP := range wgIPs { - proxyServer.OnPeerRemoved(wgIP) + proxyRelay.OnPeerRemoved(wgIP) } } @@ -769,13 +819,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { } // Update the proxy mappings in the relay server - if proxyServer == nil { + if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } - updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) + updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d", updatedCount, @@ -839,13 +889,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { } // Update the proxy mappings in the relay server - if proxyServer == nil { + if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } - proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) + proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) logger.Info("Updated proxy mapping for %s:%d with %d destinations", request.SourceIP, request.SourcePort, len(request.Destinations)) @@ -878,6 +928,12 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { return } + proxySNI.UpdateLocalSNIs(req.FullDomains) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "Local SNIs updated successfully", + }) } func periodicBandwidthCheck(endpoint string) { diff --git a/proxy/proxy.go b/proxy/proxy.go index 485a019..b553ff8 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,4 +1,4 @@ -package main +package proxy import ( "bytes" @@ -31,16 +31,16 @@ type RouteAPIResponse struct { // SNIProxy represents the main proxy server type SNIProxy struct { - port int - cache *cache.Cache - listener net.Listener - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - exitNodeName string - localProxyAddr string - localProxyPort int - apiBaseURL string + port int + cache *cache.Cache + listener net.Listener + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + exitNodeName string + localProxyAddr string + localProxyPort int + remoteConfigURL string // New fields for fast local SNI lookup localSNIs map[string]struct{} @@ -73,7 +73,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } // NewSNIProxy creates a new SNI proxy instance -func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort int, apiBaseURL string, localOverrides []string) (*SNIProxy, error) { +func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) // Create local overrides map @@ -85,17 +85,17 @@ func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort i } proxy := &SNIProxy{ - port: port, - cache: cache.New(3*time.Second, 10*time.Minute), - ctx: ctx, - cancel: cancel, - exitNodeName: exitNodeName, - localProxyAddr: localProxyAddr, - localProxyPort: localProxyPort, - apiBaseURL: apiBaseURL, - localSNIs: make(map[string]struct{}), - localOverrides: overridesMap, - activeTunnels: make(map[string]*activeTunnel), + port: port, + cache: cache.New(3*time.Second, 10*time.Minute), + ctx: ctx, + cancel: cancel, + exitNodeName: exitNodeName, + localProxyAddr: localProxyAddr, + localProxyPort: localProxyPort, + remoteConfigURL: remoteConfigURL, + localSNIs: make(map[string]struct{}), + localOverrides: overridesMap, + activeTunnels: make(map[string]*activeTunnel), } return proxy, nil @@ -337,7 +337,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { defer cancel() // Construct API URL - apiURL := fmt.Sprintf("%s/api/route/%s", p.apiBaseURL, hostname) + apiURL := fmt.Sprintf("%s/api/route/%s", p.remoteConfigURL, hostname) // Create HTTP request req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)