Update to work with multipe endpoints

This commit is contained in:
Owen
2025-08-16 22:59:45 -07:00
parent 09bd02456d
commit c970fd5a18

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"hash/fnv"
"io" "io"
"log" "log"
"net" "net"
@@ -27,6 +28,7 @@ type RouteRecord struct {
// RouteAPIResponse represents the response from the route API // RouteAPIResponse represents the response from the route API
type RouteAPIResponse struct { type RouteAPIResponse struct {
Endpoints []string `json:"endpoints"` Endpoints []string `json:"endpoints"`
Name string `json:"name"`
} }
// SNIProxy represents the main proxy server // SNIProxy represents the main proxy server
@@ -240,7 +242,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
} }
// Get routing information // Get routing information
route, err := p.getRoute(hostname) route, err := p.getRoute(hostname, clientConn.RemoteAddr().String())
if err != nil { if err != nil {
logger.Debug("Failed to get route for %s: %v", hostname, err) logger.Debug("Failed to get route for %s: %v", hostname, err)
return return
@@ -300,7 +302,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
} }
// getRoute retrieves routing information for a hostname // getRoute retrieves routing information for a hostname
func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) { func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check local overrides first // Check local overrides first
if _, isOverride := p.localOverrides[hostname]; isOverride { if _, isOverride := p.localOverrides[hostname]; isOverride {
logger.Debug("Local override matched for hostname: %s", hostname) logger.Debug("Local override matched for hostname: %s", hostname)
@@ -386,13 +388,23 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
endpoints := apiResponse.Endpoints endpoints := apiResponse.Endpoints
name := apiResponse.Name name := apiResponse.Name
// If the endpoint matches the current exit node, use the local proxy address // Default target configuration
targetHost := endpoint targetHost := p.localProxyAddr
targetPort := 443 // Default HTTPS port targetPort := p.localProxyPort
if name == p.exitNodeName {
targetHost = p.localProxyAddr // If no endpoints returned, use local node
targetPort = p.localProxyPort if len(endpoints) == 0 {
} // THIS IS SAYING TO ROUTE IT LOCALLY IF IT MATCHES - idk HOW TO KEEP THIS logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
} else if name == p.exitNodeName {
// If the endpoint matches the current exit node, use the local proxy address
logger.Debug("Exit node name matches current node (%s), using local proxy", name)
} else {
// Select endpoint using consistent hashing for stickiness
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
targetHost = selectedEndpoint
targetPort = 443 // Default HTTPS port
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
}
route := &RouteRecord{ route := &RouteRecord{
Hostname: hostname, Hostname: hostname,
@@ -407,6 +419,24 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
return route, nil return route, nil
} }
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
// the same client always routes to the same endpoint for load balancing
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
if len(endpoints) == 0 {
return p.localProxyAddr
}
if len(endpoints) == 1 {
return endpoints[0]
}
// Use FNV hash for consistent selection based on client address
hash := fnv.New32a()
hash.Write([]byte(clientAddr))
index := hash.Sum32() % uint32(len(endpoints))
return endpoints[index]
}
// pipe handles bidirectional data transfer between connections // pipe handles bidirectional data transfer between connections
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
var wg sync.WaitGroup var wg sync.WaitGroup