mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-08 05:56:40 +00:00
Update to work with multipe endpoints
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@@ -27,6 +28,7 @@ type RouteRecord struct {
|
||||
// RouteAPIResponse represents the response from the route API
|
||||
type RouteAPIResponse struct {
|
||||
Endpoints []string `json:"endpoints"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// SNIProxy represents the main proxy server
|
||||
@@ -240,7 +242,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
}
|
||||
|
||||
// Get routing information
|
||||
route, err := p.getRoute(hostname)
|
||||
route, err := p.getRoute(hostname, clientConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
||||
return
|
||||
@@ -300,7 +302,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
}
|
||||
|
||||
// getRoute retrieves routing information for a hostname
|
||||
func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
|
||||
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
// Check local overrides first
|
||||
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
||||
logger.Debug("Local override matched for hostname: %s", hostname)
|
||||
@@ -386,13 +388,23 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
|
||||
endpoints := apiResponse.Endpoints
|
||||
name := apiResponse.Name
|
||||
|
||||
// If the endpoint matches the current exit node, use the local proxy address
|
||||
targetHost := endpoint
|
||||
targetPort := 443 // Default HTTPS port
|
||||
if name == p.exitNodeName {
|
||||
targetHost = p.localProxyAddr
|
||||
targetPort = p.localProxyPort
|
||||
} // THIS IS SAYING TO ROUTE IT LOCALLY IF IT MATCHES - idk HOW TO KEEP THIS
|
||||
// Default target configuration
|
||||
targetHost := p.localProxyAddr
|
||||
targetPort := p.localProxyPort
|
||||
|
||||
// If no endpoints returned, use local node
|
||||
if len(endpoints) == 0 {
|
||||
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
|
||||
} else if name == p.exitNodeName {
|
||||
// If the endpoint matches the current exit node, use the local proxy address
|
||||
logger.Debug("Exit node name matches current node (%s), using local proxy", name)
|
||||
} else {
|
||||
// Select endpoint using consistent hashing for stickiness
|
||||
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
|
||||
targetHost = selectedEndpoint
|
||||
targetPort = 443 // Default HTTPS port
|
||||
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
|
||||
}
|
||||
|
||||
route := &RouteRecord{
|
||||
Hostname: hostname,
|
||||
@@ -407,6 +419,24 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
|
||||
// the same client always routes to the same endpoint for load balancing
|
||||
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
|
||||
if len(endpoints) == 0 {
|
||||
return p.localProxyAddr
|
||||
}
|
||||
if len(endpoints) == 1 {
|
||||
return endpoints[0]
|
||||
}
|
||||
|
||||
// Use FNV hash for consistent selection based on client address
|
||||
hash := fnv.New32a()
|
||||
hash.Write([]byte(clientAddr))
|
||||
index := hash.Sum32() % uint32(len(endpoints))
|
||||
|
||||
return endpoints[index]
|
||||
}
|
||||
|
||||
// pipe handles bidirectional data transfer between connections
|
||||
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
Reference in New Issue
Block a user