From f71f18388686f27ee2850165fec03b5a6a158903 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 12 Aug 2025 18:02:34 -0700 Subject: [PATCH 1/9] Add basic proxy --- go.mod | 1 + go.sum | 2 + main.go | 29 +++ proxy/proxy.go | 483 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 515 insertions(+) create mode 100644 proxy/proxy.go diff --git a/go.mod b/go.mod index 46d796e..c17865a 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/vishvananda/netns v0.0.5 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect diff --git a/go.sum b/go.sum index 9a60eee..a089b8e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= diff --git a/main.go b/main.go index e2d2c16..b7b8ba3 100644 --- a/main.go +++ b/main.go @@ -265,6 +265,15 @@ func main() { } defer proxyServer.Stop() + proxySNI, err := NewSNIProxy(*port, config.Sidecar.ExitNodeName, *localProxyAddr, *localProxyPort, localOverrides) + if err != nil { + logger.Fatal("Failed to create proxy: %v", err) + } + + if err := proxySNI.Start(); err != nil { + logger.Fatal("Failed to start proxy: %v", err) + } + // Set up HTTP server http.HandleFunc("/peer", handlePeer) http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) @@ -851,6 +860,26 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { }) } +// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs +type UpdateLocalSNIsRequest struct { + FullDomains []string `json:"fullDomains"` +} + +func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + logger.Error("Invalid method: %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req UpdateLocalSNIsRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid JSON payload", http.StatusBadRequest) + return + } + +} + func periodicBandwidthCheck(endpoint string) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 0000000..485a019 --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,483 @@ +package main + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "sync" + "time" + + "github.com/patrickmn/go-cache" +) + +// RouteRecord represents a routing configuration +type RouteRecord struct { + Hostname string + TargetHost string + TargetPort int +} + +// RouteAPIResponse represents the response from the route API +type RouteAPIResponse struct { + Endpoint string `json:"endpoint"` + Name string `json:"name"` +} + +// SNIProxy represents the main proxy server +type SNIProxy struct { + port int + cache *cache.Cache + listener net.Listener + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + exitNodeName string + localProxyAddr string + localProxyPort int + apiBaseURL string + + // New fields for fast local SNI lookup + localSNIs map[string]struct{} + localSNIsLock sync.RWMutex + + // Local overrides for domains that should always use local proxy + localOverrides map[string]struct{} + + // Track active tunnels by SNI + activeTunnels map[string]*activeTunnel + activeTunnelsLock sync.Mutex +} + +type activeTunnel struct { + conns []net.Conn +} + +// readOnlyConn is a wrapper for io.Reader that implements net.Conn +type readOnlyConn struct { + reader io.Reader +} + +func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) } +func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } +func (conn readOnlyConn) Close() error { return nil } +func (conn readOnlyConn) LocalAddr() net.Addr { return nil } +func (conn readOnlyConn) RemoteAddr() net.Addr { return nil } +func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } + +// NewSNIProxy creates a new SNI proxy instance +func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort int, apiBaseURL string, localOverrides []string) (*SNIProxy, error) { + ctx, cancel := context.WithCancel(context.Background()) + + // Create local overrides map + overridesMap := make(map[string]struct{}) + for _, domain := range localOverrides { + if domain != "" { + overridesMap[domain] = struct{}{} + } + } + + proxy := &SNIProxy{ + port: port, + cache: cache.New(3*time.Second, 10*time.Minute), + ctx: ctx, + cancel: cancel, + exitNodeName: exitNodeName, + localProxyAddr: localProxyAddr, + localProxyPort: localProxyPort, + apiBaseURL: apiBaseURL, + localSNIs: make(map[string]struct{}), + localOverrides: overridesMap, + activeTunnels: make(map[string]*activeTunnel), + } + + return proxy, nil +} + +// Start begins listening for connections +func (p *SNIProxy) Start() error { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port)) + if err != nil { + return fmt.Errorf("failed to listen on port %d: %w", p.port, err) + } + + p.listener = listener + log.Printf("SNI Proxy listening on port %d", p.port) + + // Accept connections in a goroutine + go p.acceptConnections() + + return nil +} + +// Stop gracefully shuts down the proxy +func (p *SNIProxy) Stop() error { + log.Println("Stopping SNI Proxy...") + + p.cancel() + + if p.listener != nil { + p.listener.Close() + } + + // Wait for all goroutines to finish with timeout + done := make(chan struct{}) + go func() { + p.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Println("All connections closed gracefully") + case <-time.After(30 * time.Second): + log.Println("Timeout waiting for connections to close") + } + + log.Println("SNI Proxy stopped") + return nil +} + +// acceptConnections handles incoming connections +func (p *SNIProxy) acceptConnections() { + for { + conn, err := p.listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + log.Printf("Accept error: %v", err) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(conn) + } +} + +// readClientHello reads and parses the TLS ClientHello message +func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) { + var hello *tls.ClientHelloInfo + err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + hello = new(tls.ClientHelloInfo) + *hello = *argHello + return nil, nil + }, + }).Handshake() + if hello == nil { + return nil, err + } + return hello, nil +} + +// peekClientHello reads the ClientHello while preserving the data for forwarding +func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) { + peekedBytes := new(bytes.Buffer) + hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes)) + if err != nil { + return nil, nil, err + } + return hello, io.MultiReader(peekedBytes, reader), nil +} + +// extractSNI extracts the SNI hostname from the TLS ClientHello +func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) { + clientHello, clientReader, err := p.peekClientHello(conn) + if err != nil { + return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err) + } + + if clientHello.ServerName == "" { + return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello") + } + + return clientHello.ServerName, clientReader, nil +} + +// handleConnection processes a single client connection +func (p *SNIProxy) handleConnection(clientConn net.Conn) { + defer p.wg.Done() + defer clientConn.Close() + + log.Printf("Accepted connection from %s", clientConn.RemoteAddr()) + + // Set read timeout for SNI extraction + if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Printf("Failed to set read deadline: %v", err) + return + } + + // Extract SNI hostname + hostname, clientReader, err := p.extractSNI(clientConn) + if err != nil { + log.Printf("SNI extraction failed: %v", err) + return + } + + if hostname == "" { + log.Println("No SNI hostname found") + return + } + + log.Printf("SNI hostname detected: %s", hostname) + + // Remove read timeout for normal operation + if err := clientConn.SetReadDeadline(time.Time{}); err != nil { + log.Printf("Failed to clear read deadline: %v", err) + return + } + + // Get routing information + route, err := p.getRoute(hostname) + if err != nil { + log.Printf("Failed to get route for %s: %v", hostname, err) + return + } + + if route == nil { + log.Printf("No route found for hostname: %s", hostname) + return + } + + log.Printf("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort) + + // Connect to target server + targetConn, err := net.DialTimeout("tcp", + fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort), + 10*time.Second) + if err != nil { + log.Printf("Failed to connect to target %s:%d: %v", + route.TargetHost, route.TargetPort, err) + return + } + defer targetConn.Close() + + log.Printf("Connected to target: %s:%d", route.TargetHost, route.TargetPort) + + // Track this tunnel by SNI + p.activeTunnelsLock.Lock() + tunnel, ok := p.activeTunnels[hostname] + if !ok { + tunnel = &activeTunnel{} + p.activeTunnels[hostname] = tunnel + } + tunnel.conns = append(tunnel.conns, clientConn) + p.activeTunnelsLock.Unlock() + + defer func() { + // Remove this conn from active tunnels + p.activeTunnelsLock.Lock() + if tunnel, ok := p.activeTunnels[hostname]; ok { + newConns := make([]net.Conn, 0, len(tunnel.conns)) + for _, c := range tunnel.conns { + if c != clientConn { + newConns = append(newConns, c) + } + } + if len(newConns) == 0 { + delete(p.activeTunnels, hostname) + } else { + tunnel.conns = newConns + } + } + p.activeTunnelsLock.Unlock() + }() + + // Start bidirectional data transfer + p.pipe(clientConn, targetConn, clientReader) +} + +// getRoute retrieves routing information for a hostname +func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { + // Check local overrides first + if _, isOverride := p.localOverrides[hostname]; isOverride { + log.Printf("Local override matched for hostname: %s", hostname) + return &RouteRecord{ + Hostname: hostname, + TargetHost: p.localProxyAddr, + TargetPort: p.localProxyPort, + }, nil + } + + // Fast path: check if hostname is in localSNIs + p.localSNIsLock.RLock() + _, isLocal := p.localSNIs[hostname] + p.localSNIsLock.RUnlock() + if isLocal { + return &RouteRecord{ + Hostname: hostname, + TargetHost: p.localProxyAddr, + TargetPort: p.localProxyPort, + }, nil + } + + // Check cache first + if cached, found := p.cache.Get(hostname); found { + if cached == nil { + return nil, nil // Cached negative result + } + log.Printf("Cache hit for hostname: %s", hostname) + return cached.(*RouteRecord), nil + } + + log.Printf("Cache miss for hostname: %s, querying API", hostname) + + // Query API with timeout + ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) + defer cancel() + + // Construct API URL + apiURL := fmt.Sprintf("%s/api/route/%s", p.apiBaseURL, hostname) + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Make HTTP request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("API request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // Cache negative result for shorter time (1 minute) + p.cache.Set(hostname, nil, 1*time.Minute) + return nil, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API returned status %d", resp.StatusCode) + } + + // Parse response + var apiResponse RouteAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + return nil, fmt.Errorf("failed to decode API response: %w", err) + } + + endpoint := apiResponse.Endpoint + name := apiResponse.Name + + // If the endpoint matches the current exit node, use the local proxy address + targetHost := endpoint + targetPort := 443 // Default HTTPS port + if name == p.exitNodeName { + targetHost = p.localProxyAddr + targetPort = p.localProxyPort + } + + route := &RouteRecord{ + Hostname: hostname, + TargetHost: targetHost, + TargetPort: targetPort, + } + + // Cache the result + p.cache.Set(hostname, route, cache.DefaultExpiration) + log.Printf("Cached route for hostname: %s", hostname) + + return route, nil +} + +// pipe handles bidirectional data transfer between connections +func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { + var wg sync.WaitGroup + wg.Add(2) + + // Copy data from client to target (using the buffered reader) + go func() { + defer wg.Done() + defer func() { + if tcpConn, ok := targetConn.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } + }() + + // Use a large buffer for better performance + buf := make([]byte, 32*1024) + _, err := io.CopyBuffer(targetConn, clientReader, buf) + if err != nil && err != io.EOF { + log.Printf("Copy client->target error: %v", err) + } + }() + + // Copy data from target to client + go func() { + defer wg.Done() + defer func() { + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } + }() + + // Use a large buffer for better performance + buf := make([]byte, 32*1024) + _, err := io.CopyBuffer(clientConn, targetConn, buf) + if err != nil && err != io.EOF { + log.Printf("Copy target->client error: %v", err) + } + }() + + wg.Wait() +} + +// GetCacheStats returns cache statistics +func (p *SNIProxy) GetCacheStats() (int, int) { + return p.cache.ItemCount(), len(p.cache.Items()) +} + +// ClearCache clears all cached entries +func (p *SNIProxy) ClearCache() { + p.cache.Flush() + log.Println("Cache cleared") +} + +// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains +func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) { + newSNIs := make(map[string]struct{}) + for _, domain := range fullDomains { + newSNIs[domain] = struct{}{} + // Invalidate any cached route for this domain + p.cache.Delete(domain) + } + + // Update localSNIs + p.localSNIsLock.Lock() + removed := make([]string, 0) + for sni := range p.localSNIs { + if _, stillLocal := newSNIs[sni]; !stillLocal { + removed = append(removed, sni) + } + } + p.localSNIs = newSNIs + p.localSNIsLock.Unlock() + + // Terminate tunnels for removed SNIs + if len(removed) > 0 { + p.activeTunnelsLock.Lock() + for _, sni := range removed { + if tunnels, ok := p.activeTunnels[sni]; ok { + for _, conn := range tunnels.conns { + conn.Close() + } + delete(p.activeTunnels, sni) + log.Printf("Closed tunnels for SNI target change: %s", sni) + } + } + p.activeTunnelsLock.Unlock() + } +} From 1df5eb19ff338d93e571bd7d84813f14a5bb7acc Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 13 Aug 2025 15:41:58 -0700 Subject: [PATCH 2/9] Integrate sni proxy --- main.go | 84 +++++++++++++++++++++++++++++++++++++++++--------- proxy/proxy.go | 48 ++++++++++++++--------------- 2 files changed, 94 insertions(+), 38 deletions(-) diff --git a/main.go b/main.go index b7b8ba3..d6fcf73 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "time" "github.com/fosrl/gerbil/logger" + "github.com/fosrl/gerbil/proxy" "github.com/fosrl/gerbil/relay" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" @@ -32,7 +33,8 @@ var ( mu sync.Mutex wgMu sync.Mutex // Protects WireGuard operations notifyURL string - proxyServer *relay.UDPProxyServer + proxyRelay *relay.UDPProxyServer + proxySNI *proxy.SNIProxy ) type WgConfig struct { @@ -115,6 +117,10 @@ func main() { reachableAt string logLevel string mtu string + sniProxyPort int + localProxyAddr string + localProxyPort int + localOverridesStr string ) interfaceName = os.Getenv("INTERFACE") @@ -127,6 +133,11 @@ func main() { mtu = os.Getenv("MTU") notifyURL = os.Getenv("NOTIFY_URL") + sniProxyPortStr := os.Getenv("SNI_PORT") + localProxyAddr = os.Getenv("LOCAL_PROXY") + localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT") + localOverridesStr = os.Getenv("LOCAL_OVERRIDES") + if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } @@ -159,6 +170,32 @@ func main() { if notifyURL == "" { flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes") } + + if sniProxyPortStr != "" { + if port, err := strconv.Atoi(sniProxyPortStr); err == nil { + sniProxyPort = port + } + } + if sniProxyPortStr == "" { + flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on") + } + + if localProxyAddr == "" { + flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address") + } + + if localProxyPortStr != "" { + if port, err := strconv.Atoi(localProxyPortStr); err == nil { + localProxyPort = port + } + } + if localProxyPortStr == "" { + flag.IntVar(&localProxyPort, "local-proxy-port", 9443, "Local proxy port") + } + if localOverridesStr != "" { + flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy") + } + flag.Parse() logger.Init() @@ -258,14 +295,26 @@ func main() { go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") // Start the UDP proxy server - proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) - err = proxyServer.Start() + proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) + err = proxyRelay.Start() if err != nil { logger.Fatal("Failed to start UDP proxy server: %v", err) } - defer proxyServer.Stop() + defer proxyRelay.Stop() - proxySNI, err := NewSNIProxy(*port, config.Sidecar.ExitNodeName, *localProxyAddr, *localProxyPort, localOverrides) + // TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING + // SO YOU DON'T NEED TO SET THIS SEPARATELY + // Parse local overrides + var localOverrides []string + if localOverridesStr != "" { + localOverrides = strings.Split(localOverridesStr, ",") + for i, domain := range localOverrides { + localOverrides[i] = strings.TrimSpace(domain) + } + logger.Info("Local overrides configured: %v", localOverrides) + } + + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, "david", localProxyAddr, localProxyPort, localOverrides) if err != nil { logger.Fatal("Failed to create proxy: %v", err) } @@ -278,6 +327,7 @@ func main() { http.HandleFunc("/peer", handlePeer) http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) http.HandleFunc("/update-destinations", handleUpdateDestinations) + http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) logger.Info("Starting HTTP server on %s", listenAddr) // Run HTTP server in a goroutine @@ -656,9 +706,9 @@ func addPeerInternal(peer Peer) error { } // Clear relay connections for the peer's WireGuard IPs - if proxyServer != nil { + if proxyRelay != nil { for _, wgIP := range wgIPs { - proxyServer.OnPeerAdded(wgIP) + proxyRelay.OnPeerAdded(wgIP) } } @@ -701,7 +751,7 @@ func removePeerInternal(publicKey string) error { // Get current peer info before removing to clear relay connections var wgIPs []string - if proxyServer != nil { + if proxyRelay != nil { device, err := wgClient.Device(interfaceName) if err == nil { for _, peer := range device.Peers { @@ -730,9 +780,9 @@ func removePeerInternal(publicKey string) error { } // Clear relay connections for the peer's WireGuard IPs - if proxyServer != nil { + if proxyRelay != nil { for _, wgIP := range wgIPs { - proxyServer.OnPeerRemoved(wgIP) + proxyRelay.OnPeerRemoved(wgIP) } } @@ -769,13 +819,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { } // Update the proxy mappings in the relay server - if proxyServer == nil { + if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } - updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) + updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d", updatedCount, @@ -839,13 +889,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { } // Update the proxy mappings in the relay server - if proxyServer == nil { + if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } - proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) + proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) logger.Info("Updated proxy mapping for %s:%d with %d destinations", request.SourceIP, request.SourcePort, len(request.Destinations)) @@ -878,6 +928,12 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { return } + proxySNI.UpdateLocalSNIs(req.FullDomains) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "Local SNIs updated successfully", + }) } func periodicBandwidthCheck(endpoint string) { diff --git a/proxy/proxy.go b/proxy/proxy.go index 485a019..b553ff8 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,4 +1,4 @@ -package main +package proxy import ( "bytes" @@ -31,16 +31,16 @@ type RouteAPIResponse struct { // SNIProxy represents the main proxy server type SNIProxy struct { - port int - cache *cache.Cache - listener net.Listener - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - exitNodeName string - localProxyAddr string - localProxyPort int - apiBaseURL string + port int + cache *cache.Cache + listener net.Listener + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + exitNodeName string + localProxyAddr string + localProxyPort int + remoteConfigURL string // New fields for fast local SNI lookup localSNIs map[string]struct{} @@ -73,7 +73,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } // NewSNIProxy creates a new SNI proxy instance -func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort int, apiBaseURL string, localOverrides []string) (*SNIProxy, error) { +func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) // Create local overrides map @@ -85,17 +85,17 @@ func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort i } proxy := &SNIProxy{ - port: port, - cache: cache.New(3*time.Second, 10*time.Minute), - ctx: ctx, - cancel: cancel, - exitNodeName: exitNodeName, - localProxyAddr: localProxyAddr, - localProxyPort: localProxyPort, - apiBaseURL: apiBaseURL, - localSNIs: make(map[string]struct{}), - localOverrides: overridesMap, - activeTunnels: make(map[string]*activeTunnel), + port: port, + cache: cache.New(3*time.Second, 10*time.Minute), + ctx: ctx, + cancel: cancel, + exitNodeName: exitNodeName, + localProxyAddr: localProxyAddr, + localProxyPort: localProxyPort, + remoteConfigURL: remoteConfigURL, + localSNIs: make(map[string]struct{}), + localOverrides: overridesMap, + activeTunnels: make(map[string]*activeTunnel), } return proxy, nil @@ -337,7 +337,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { defer cancel() // Construct API URL - apiURL := fmt.Sprintf("%s/api/route/%s", p.apiBaseURL, hostname) + apiURL := fmt.Sprintf("%s/api/route/%s", p.remoteConfigURL, hostname) // Create HTTP request req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) From 10958f8c5589770110d41b6b03caaeacdb5fa065 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 14 Aug 2025 22:25:38 -0700 Subject: [PATCH 3/9] Use propper logger --- proxy/proxy.go | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index b553ff8..6079c29 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/fosrl/gerbil/logger" "github.com/patrickmn/go-cache" ) @@ -109,7 +110,7 @@ func (p *SNIProxy) Start() error { } p.listener = listener - log.Printf("SNI Proxy listening on port %d", p.port) + logger.Debug("SNI Proxy listening on port %d", p.port) // Accept connections in a goroutine go p.acceptConnections() @@ -154,7 +155,7 @@ func (p *SNIProxy) acceptConnections() { case <-p.ctx.Done(): return default: - log.Printf("Accept error: %v", err) + logger.Debug("Accept error: %v", err) continue } } @@ -209,18 +210,18 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { defer p.wg.Done() defer clientConn.Close() - log.Printf("Accepted connection from %s", clientConn.RemoteAddr()) + logger.Debug("Accepted connection from %s", clientConn.RemoteAddr()) // Set read timeout for SNI extraction if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Printf("Failed to set read deadline: %v", err) + logger.Debug("Failed to set read deadline: %v", err) return } // Extract SNI hostname hostname, clientReader, err := p.extractSNI(clientConn) if err != nil { - log.Printf("SNI extraction failed: %v", err) + logger.Debug("SNI extraction failed: %v", err) return } @@ -229,40 +230,40 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { return } - log.Printf("SNI hostname detected: %s", hostname) + logger.Debug("SNI hostname detected: %s", hostname) // Remove read timeout for normal operation if err := clientConn.SetReadDeadline(time.Time{}); err != nil { - log.Printf("Failed to clear read deadline: %v", err) + logger.Debug("Failed to clear read deadline: %v", err) return } // Get routing information route, err := p.getRoute(hostname) if err != nil { - log.Printf("Failed to get route for %s: %v", hostname, err) + logger.Debug("Failed to get route for %s: %v", hostname, err) return } if route == nil { - log.Printf("No route found for hostname: %s", hostname) + logger.Debug("No route found for hostname: %s", hostname) return } - log.Printf("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort) + logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort) // Connect to target server targetConn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort), 10*time.Second) if err != nil { - log.Printf("Failed to connect to target %s:%d: %v", + logger.Debug("Failed to connect to target %s:%d: %v", route.TargetHost, route.TargetPort, err) return } defer targetConn.Close() - log.Printf("Connected to target: %s:%d", route.TargetHost, route.TargetPort) + logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) // Track this tunnel by SNI p.activeTunnelsLock.Lock() @@ -301,7 +302,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { // Check local overrides first if _, isOverride := p.localOverrides[hostname]; isOverride { - log.Printf("Local override matched for hostname: %s", hostname) + logger.Debug("Local override matched for hostname: %s", hostname) return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -326,11 +327,11 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { if cached == nil { return nil, nil // Cached negative result } - log.Printf("Cache hit for hostname: %s", hostname) + logger.Debug("Cache hit for hostname: %s", hostname) return cached.(*RouteRecord), nil } - log.Printf("Cache miss for hostname: %s, querying API", hostname) + logger.Debug("Cache miss for hostname: %s, querying API", hostname) // Query API with timeout ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) @@ -388,7 +389,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { // Cache the result p.cache.Set(hostname, route, cache.DefaultExpiration) - log.Printf("Cached route for hostname: %s", hostname) + logger.Debug("Cached route for hostname: %s", hostname) return route, nil } @@ -411,7 +412,7 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) buf := make([]byte, 32*1024) _, err := io.CopyBuffer(targetConn, clientReader, buf) if err != nil && err != io.EOF { - log.Printf("Copy client->target error: %v", err) + logger.Debug("Copy client->target error: %v", err) } }() @@ -428,7 +429,7 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) buf := make([]byte, 32*1024) _, err := io.CopyBuffer(clientConn, targetConn, buf) if err != nil && err != io.EOF { - log.Printf("Copy target->client error: %v", err) + logger.Debug("Copy target->client error: %v", err) } }() @@ -466,6 +467,8 @@ func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) { p.localSNIs = newSNIs p.localSNIsLock.Unlock() + logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed)) + // Terminate tunnels for removed SNIs if len(removed) > 0 { p.activeTunnelsLock.Lock() @@ -475,7 +478,7 @@ func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) { conn.Close() } delete(p.activeTunnels, sni) - log.Printf("Closed tunnels for SNI target change: %s", sni) + logger.Debug("Closed tunnels for SNI target change: %s", sni) } } p.activeTunnelsLock.Unlock() From 9de3f14799e901c9df02a3ccc8d009c823cc2dd9 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 16 Aug 2025 22:35:51 -0700 Subject: [PATCH 4/9] Update default config --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 830085b..54c3b55 100644 --- a/README.md +++ b/README.md @@ -62,8 +62,7 @@ Example: ./gerbil \ --reachableAt=http://gerbil:3003 \ --generateAndSaveKeyTo=/var/config/key \ ---remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config \ ---reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth +--remoteConfig=http://pangolin:3001/api/v1/ ``` ```yaml @@ -75,8 +74,7 @@ services: command: - --reachableAt=http://gerbil:3003 - --generateAndSaveKeyTo=/var/config/key - - --remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config - - --reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth + - --remoteConfig=http://pangolin:3001/api/v1/ volumes: - ./config/:/var/config cap_add: From c24537af36da516f77f8523e778bbb9358e7d857 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 16 Aug 2025 22:36:03 -0700 Subject: [PATCH 5/9] Fix url --- proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 6079c29..ba9188c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -338,7 +338,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { defer cancel() // Construct API URL - apiURL := fmt.Sprintf("%s/api/route/%s", p.remoteConfigURL, hostname) + apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname/%s", p.remoteConfigURL, hostname) // Create HTTP request req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) From 09bd02456dcf29ecfb9315381786d35b8931e4be Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 16 Aug 2025 22:53:49 -0700 Subject: [PATCH 6/9] Move to post --- proxy/proxy.go | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index ba9188c..dee870c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -26,8 +26,7 @@ type RouteRecord struct { // RouteAPIResponse represents the response from the route API type RouteAPIResponse struct { - Endpoint string `json:"endpoint"` - Name string `json:"name"` + Endpoints []string `json:"endpoints"` } // SNIProxy represents the main proxy server @@ -42,6 +41,7 @@ type SNIProxy struct { localProxyAddr string localProxyPort int remoteConfigURL string + publicKey string // New fields for fast local SNI lookup localSNIs map[string]struct{} @@ -74,7 +74,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } // NewSNIProxy creates a new SNI proxy instance -func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { +func NewSNIProxy(port int, remoteConfigURL, publicKey, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) // Create local overrides map @@ -94,6 +94,7 @@ func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyAddr: localProxyAddr, localProxyPort: localProxyPort, remoteConfigURL: remoteConfigURL, + publicKey: publicKey, localSNIs: make(map[string]struct{}), localOverrides: overridesMap, activeTunnels: make(map[string]*activeTunnel), @@ -337,14 +338,26 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) defer cancel() - // Construct API URL - apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname/%s", p.remoteConfigURL, hostname) + // Construct API URL (without hostname in path) + apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL) + + // Create request body with hostname and public key + requestBody := map[string]string{ + "hostname": hostname, + "publicKey": p.publicKey, + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } + req.Header.Set("Content-Type", "application/json") // Make HTTP request client := &http.Client{Timeout: 5 * time.Second} @@ -370,7 +383,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { return nil, fmt.Errorf("failed to decode API response: %w", err) } - endpoint := apiResponse.Endpoint + endpoints := apiResponse.Endpoints name := apiResponse.Name // If the endpoint matches the current exit node, use the local proxy address @@ -379,7 +392,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { if name == p.exitNodeName { targetHost = p.localProxyAddr targetPort = p.localProxyPort - } + } // THIS IS SAYING TO ROUTE IT LOCALLY IF IT MATCHES - idk HOW TO KEEP THIS route := &RouteRecord{ Hostname: hostname, From c970fd5a18623e1ec61f0bbbd02f92b2dc37bc89 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 16 Aug 2025 22:59:45 -0700 Subject: [PATCH 7/9] Update to work with multipe endpoints --- proxy/proxy.go | 48 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index dee870c..14730ee 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "hash/fnv" "io" "log" "net" @@ -27,6 +28,7 @@ type RouteRecord struct { // RouteAPIResponse represents the response from the route API type RouteAPIResponse struct { Endpoints []string `json:"endpoints"` + Name string `json:"name"` } // SNIProxy represents the main proxy server @@ -240,7 +242,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { } // Get routing information - route, err := p.getRoute(hostname) + route, err := p.getRoute(hostname, clientConn.RemoteAddr().String()) if err != nil { logger.Debug("Failed to get route for %s: %v", hostname, err) return @@ -300,7 +302,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { } // getRoute retrieves routing information for a hostname -func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { +func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check local overrides first if _, isOverride := p.localOverrides[hostname]; isOverride { logger.Debug("Local override matched for hostname: %s", hostname) @@ -386,13 +388,23 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { endpoints := apiResponse.Endpoints name := apiResponse.Name - // If the endpoint matches the current exit node, use the local proxy address - targetHost := endpoint - targetPort := 443 // Default HTTPS port - if name == p.exitNodeName { - targetHost = p.localProxyAddr - targetPort = p.localProxyPort - } // THIS IS SAYING TO ROUTE IT LOCALLY IF IT MATCHES - idk HOW TO KEEP THIS + // Default target configuration + targetHost := p.localProxyAddr + targetPort := p.localProxyPort + + // If no endpoints returned, use local node + if len(endpoints) == 0 { + logger.Debug("No endpoints returned for hostname: %s, using local node", hostname) + } else if name == p.exitNodeName { + // If the endpoint matches the current exit node, use the local proxy address + logger.Debug("Exit node name matches current node (%s), using local proxy", name) + } else { + // Select endpoint using consistent hashing for stickiness + selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints) + targetHost = selectedEndpoint + targetPort = 443 // Default HTTPS port + logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr) + } route := &RouteRecord{ Hostname: hostname, @@ -407,6 +419,24 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { return route, nil } +// selectStickyEndpoint selects an endpoint using consistent hashing to ensure +// the same client always routes to the same endpoint for load balancing +func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string { + if len(endpoints) == 0 { + return p.localProxyAddr + } + if len(endpoints) == 1 { + return endpoints[0] + } + + // Use FNV hash for consistent selection based on client address + hash := fnv.New32a() + hash.Write([]byte(clientAddr)) + index := hash.Sum32() % uint32(len(endpoints)) + + return endpoints[index] +} + // pipe handles bidirectional data transfer between connections func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { var wg sync.WaitGroup From efce3cb0b2755eb837d3d82e5bcedac042ec99dd Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 17 Aug 2025 10:43:37 -0700 Subject: [PATCH 8/9] Sni has no errors now --- main.go | 2 +- proxy/proxy.go | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/main.go b/main.go index d6fcf73..37996e6 100644 --- a/main.go +++ b/main.go @@ -314,7 +314,7 @@ func main() { logger.Info("Local overrides configured: %v", localOverrides) } - proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, "david", localProxyAddr, localProxyPort, localOverrides) + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides) if err != nil { logger.Fatal("Failed to create proxy: %v", err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 14730ee..7edd92e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -28,7 +28,6 @@ type RouteRecord struct { // RouteAPIResponse represents the response from the route API type RouteAPIResponse struct { Endpoints []string `json:"endpoints"` - Name string `json:"name"` } // SNIProxy represents the main proxy server @@ -39,7 +38,6 @@ type SNIProxy struct { ctx context.Context cancel context.CancelFunc wg sync.WaitGroup - exitNodeName string localProxyAddr string localProxyPort int remoteConfigURL string @@ -76,7 +74,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } // NewSNIProxy creates a new SNI proxy instance -func NewSNIProxy(port int, remoteConfigURL, publicKey, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { +func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { ctx, cancel := context.WithCancel(context.Background()) // Create local overrides map @@ -92,7 +90,6 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, exitNodeName, localProxyA cache: cache.New(3*time.Second, 10*time.Minute), ctx: ctx, cancel: cancel, - exitNodeName: exitNodeName, localProxyAddr: localProxyAddr, localProxyPort: localProxyPort, remoteConfigURL: remoteConfigURL, @@ -386,7 +383,6 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { } endpoints := apiResponse.Endpoints - name := apiResponse.Name // Default target configuration targetHost := p.localProxyAddr @@ -395,9 +391,6 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // If no endpoints returned, use local node if len(endpoints) == 0 { logger.Debug("No endpoints returned for hostname: %s, using local node", hostname) - } else if name == p.exitNodeName { - // If the endpoint matches the current exit node, use the local proxy address - logger.Debug("Exit node name matches current node (%s), using local proxy", name) } else { // Select endpoint using consistent hashing for stickiness selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints) From f983a8f141a3111fcf100709f3f9e3ac7cddeb48 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 22 Aug 2025 11:56:29 -0700 Subject: [PATCH 9/9] Local proxy port 443 --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 37996e6..37350ce 100644 --- a/main.go +++ b/main.go @@ -190,7 +190,7 @@ func main() { } } if localProxyPortStr == "" { - flag.IntVar(&localProxyPort, "local-proxy-port", 9443, "Local proxy port") + flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port") } if localOverridesStr != "" { flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")