diff --git a/proxy/proxy.go b/proxy/proxy.go index dee870c..14730ee 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "hash/fnv" "io" "log" "net" @@ -27,6 +28,7 @@ type RouteRecord struct { // RouteAPIResponse represents the response from the route API type RouteAPIResponse struct { Endpoints []string `json:"endpoints"` + Name string `json:"name"` } // SNIProxy represents the main proxy server @@ -240,7 +242,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { } // Get routing information - route, err := p.getRoute(hostname) + route, err := p.getRoute(hostname, clientConn.RemoteAddr().String()) if err != nil { logger.Debug("Failed to get route for %s: %v", hostname, err) return @@ -300,7 +302,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { } // getRoute retrieves routing information for a hostname -func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { +func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check local overrides first if _, isOverride := p.localOverrides[hostname]; isOverride { logger.Debug("Local override matched for hostname: %s", hostname) @@ -386,13 +388,23 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { endpoints := apiResponse.Endpoints name := apiResponse.Name - // If the endpoint matches the current exit node, use the local proxy address - targetHost := endpoint - targetPort := 443 // Default HTTPS port - if name == p.exitNodeName { - targetHost = p.localProxyAddr - targetPort = p.localProxyPort - } // THIS IS SAYING TO ROUTE IT LOCALLY IF IT MATCHES - idk HOW TO KEEP THIS + // Default target configuration + targetHost := p.localProxyAddr + targetPort := p.localProxyPort + + // If no endpoints returned, use local node + if len(endpoints) == 0 { + logger.Debug("No endpoints returned for hostname: %s, using local node", hostname) + } else if name == p.exitNodeName { + // If the endpoint matches the current exit node, use the local proxy address + logger.Debug("Exit node name matches current node (%s), using local proxy", name) + } else { + // Select endpoint using consistent hashing for stickiness + selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints) + targetHost = selectedEndpoint + targetPort = 443 // Default HTTPS port + logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr) + } route := &RouteRecord{ Hostname: hostname, @@ -407,6 +419,24 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { return route, nil } +// selectStickyEndpoint selects an endpoint using consistent hashing to ensure +// the same client always routes to the same endpoint for load balancing +func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string { + if len(endpoints) == 0 { + return p.localProxyAddr + } + if len(endpoints) == 1 { + return endpoints[0] + } + + // Use FNV hash for consistent selection based on client address + hash := fnv.New32a() + hash.Write([]byte(clientAddr)) + index := hash.Sum32() % uint32(len(endpoints)) + + return endpoints[index] +} + // pipe handles bidirectional data transfer between connections func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { var wg sync.WaitGroup