mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-26 14:56:45 +00:00
Integrate sni proxy
This commit is contained in:
84
main.go
84
main.go
@@ -18,6 +18,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/gerbil/logger"
|
"github.com/fosrl/gerbil/logger"
|
||||||
|
"github.com/fosrl/gerbil/proxy"
|
||||||
"github.com/fosrl/gerbil/relay"
|
"github.com/fosrl/gerbil/relay"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
@@ -32,7 +33,8 @@ var (
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
wgMu sync.Mutex // Protects WireGuard operations
|
wgMu sync.Mutex // Protects WireGuard operations
|
||||||
notifyURL string
|
notifyURL string
|
||||||
proxyServer *relay.UDPProxyServer
|
proxyRelay *relay.UDPProxyServer
|
||||||
|
proxySNI *proxy.SNIProxy
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgConfig struct {
|
type WgConfig struct {
|
||||||
@@ -115,6 +117,10 @@ func main() {
|
|||||||
reachableAt string
|
reachableAt string
|
||||||
logLevel string
|
logLevel string
|
||||||
mtu string
|
mtu string
|
||||||
|
sniProxyPort int
|
||||||
|
localProxyAddr string
|
||||||
|
localProxyPort int
|
||||||
|
localOverridesStr string
|
||||||
)
|
)
|
||||||
|
|
||||||
interfaceName = os.Getenv("INTERFACE")
|
interfaceName = os.Getenv("INTERFACE")
|
||||||
@@ -127,6 +133,11 @@ func main() {
|
|||||||
mtu = os.Getenv("MTU")
|
mtu = os.Getenv("MTU")
|
||||||
notifyURL = os.Getenv("NOTIFY_URL")
|
notifyURL = os.Getenv("NOTIFY_URL")
|
||||||
|
|
||||||
|
sniProxyPortStr := os.Getenv("SNI_PORT")
|
||||||
|
localProxyAddr = os.Getenv("LOCAL_PROXY")
|
||||||
|
localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT")
|
||||||
|
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
||||||
|
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||||
}
|
}
|
||||||
@@ -159,6 +170,32 @@ func main() {
|
|||||||
if notifyURL == "" {
|
if notifyURL == "" {
|
||||||
flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes")
|
flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sniProxyPortStr != "" {
|
||||||
|
if port, err := strconv.Atoi(sniProxyPortStr); err == nil {
|
||||||
|
sniProxyPort = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sniProxyPortStr == "" {
|
||||||
|
flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on")
|
||||||
|
}
|
||||||
|
|
||||||
|
if localProxyAddr == "" {
|
||||||
|
flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if localProxyPortStr != "" {
|
||||||
|
if port, err := strconv.Atoi(localProxyPortStr); err == nil {
|
||||||
|
localProxyPort = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if localProxyPortStr == "" {
|
||||||
|
flag.IntVar(&localProxyPort, "local-proxy-port", 9443, "Local proxy port")
|
||||||
|
}
|
||||||
|
if localOverridesStr != "" {
|
||||||
|
flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy")
|
||||||
|
}
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
logger.Init()
|
logger.Init()
|
||||||
@@ -258,14 +295,26 @@ func main() {
|
|||||||
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
|
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
|
||||||
|
|
||||||
// Start the UDP proxy server
|
// Start the UDP proxy server
|
||||||
proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
|
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
|
||||||
err = proxyServer.Start()
|
err = proxyRelay.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyServer.Stop()
|
defer proxyRelay.Stop()
|
||||||
|
|
||||||
proxySNI, err := NewSNIProxy(*port, config.Sidecar.ExitNodeName, *localProxyAddr, *localProxyPort, localOverrides)
|
// TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING
|
||||||
|
// SO YOU DON'T NEED TO SET THIS SEPARATELY
|
||||||
|
// Parse local overrides
|
||||||
|
var localOverrides []string
|
||||||
|
if localOverridesStr != "" {
|
||||||
|
localOverrides = strings.Split(localOverridesStr, ",")
|
||||||
|
for i, domain := range localOverrides {
|
||||||
|
localOverrides[i] = strings.TrimSpace(domain)
|
||||||
|
}
|
||||||
|
logger.Info("Local overrides configured: %v", localOverrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, "david", localProxyAddr, localProxyPort, localOverrides)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create proxy: %v", err)
|
logger.Fatal("Failed to create proxy: %v", err)
|
||||||
}
|
}
|
||||||
@@ -278,6 +327,7 @@ func main() {
|
|||||||
http.HandleFunc("/peer", handlePeer)
|
http.HandleFunc("/peer", handlePeer)
|
||||||
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
|
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
|
||||||
http.HandleFunc("/update-destinations", handleUpdateDestinations)
|
http.HandleFunc("/update-destinations", handleUpdateDestinations)
|
||||||
|
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||||
|
|
||||||
// Run HTTP server in a goroutine
|
// Run HTTP server in a goroutine
|
||||||
@@ -656,9 +706,9 @@ func addPeerInternal(peer Peer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clear relay connections for the peer's WireGuard IPs
|
// Clear relay connections for the peer's WireGuard IPs
|
||||||
if proxyServer != nil {
|
if proxyRelay != nil {
|
||||||
for _, wgIP := range wgIPs {
|
for _, wgIP := range wgIPs {
|
||||||
proxyServer.OnPeerAdded(wgIP)
|
proxyRelay.OnPeerAdded(wgIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -701,7 +751,7 @@ func removePeerInternal(publicKey string) error {
|
|||||||
|
|
||||||
// Get current peer info before removing to clear relay connections
|
// Get current peer info before removing to clear relay connections
|
||||||
var wgIPs []string
|
var wgIPs []string
|
||||||
if proxyServer != nil {
|
if proxyRelay != nil {
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := wgClient.Device(interfaceName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range device.Peers {
|
||||||
@@ -730,9 +780,9 @@ func removePeerInternal(publicKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clear relay connections for the peer's WireGuard IPs
|
// Clear relay connections for the peer's WireGuard IPs
|
||||||
if proxyServer != nil {
|
if proxyRelay != nil {
|
||||||
for _, wgIP := range wgIPs {
|
for _, wgIP := range wgIPs {
|
||||||
proxyServer.OnPeerRemoved(wgIP)
|
proxyRelay.OnPeerRemoved(wgIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -769,13 +819,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update the proxy mappings in the relay server
|
// Update the proxy mappings in the relay server
|
||||||
if proxyServer == nil {
|
if proxyRelay == nil {
|
||||||
logger.Error("Proxy server is not available")
|
logger.Error("Proxy server is not available")
|
||||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
|
updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination)
|
||||||
|
|
||||||
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
|
logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d",
|
||||||
updatedCount,
|
updatedCount,
|
||||||
@@ -839,13 +889,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update the proxy mappings in the relay server
|
// Update the proxy mappings in the relay server
|
||||||
if proxyServer == nil {
|
if proxyRelay == nil {
|
||||||
logger.Error("Proxy server is not available")
|
logger.Error("Proxy server is not available")
|
||||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
|
proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations)
|
||||||
|
|
||||||
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
|
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
|
||||||
request.SourceIP, request.SourcePort, len(request.Destinations))
|
request.SourceIP, request.SourcePort, len(request.Destinations))
|
||||||
@@ -878,6 +928,12 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proxySNI.UpdateLocalSNIs(req.FullDomains)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"status": "Local SNIs updated successfully",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func periodicBandwidthCheck(endpoint string) {
|
func periodicBandwidthCheck(endpoint string) {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -31,16 +31,16 @@ type RouteAPIResponse struct {
|
|||||||
|
|
||||||
// SNIProxy represents the main proxy server
|
// SNIProxy represents the main proxy server
|
||||||
type SNIProxy struct {
|
type SNIProxy struct {
|
||||||
port int
|
port int
|
||||||
cache *cache.Cache
|
cache *cache.Cache
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
exitNodeName string
|
exitNodeName string
|
||||||
localProxyAddr string
|
localProxyAddr string
|
||||||
localProxyPort int
|
localProxyPort int
|
||||||
apiBaseURL string
|
remoteConfigURL string
|
||||||
|
|
||||||
// New fields for fast local SNI lookup
|
// New fields for fast local SNI lookup
|
||||||
localSNIs map[string]struct{}
|
localSNIs map[string]struct{}
|
||||||
@@ -73,7 +73,7 @@ func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
|||||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
// NewSNIProxy creates a new SNI proxy instance
|
// NewSNIProxy creates a new SNI proxy instance
|
||||||
func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort int, apiBaseURL string, localOverrides []string) (*SNIProxy, error) {
|
func NewSNIProxy(port int, remoteConfigURL, exitNodeName, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
// Create local overrides map
|
// Create local overrides map
|
||||||
@@ -85,17 +85,17 @@ func NewSNIProxy(port int, exitNodeName, localProxyAddr string, localProxyPort i
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxy := &SNIProxy{
|
proxy := &SNIProxy{
|
||||||
port: port,
|
port: port,
|
||||||
cache: cache.New(3*time.Second, 10*time.Minute),
|
cache: cache.New(3*time.Second, 10*time.Minute),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
exitNodeName: exitNodeName,
|
exitNodeName: exitNodeName,
|
||||||
localProxyAddr: localProxyAddr,
|
localProxyAddr: localProxyAddr,
|
||||||
localProxyPort: localProxyPort,
|
localProxyPort: localProxyPort,
|
||||||
apiBaseURL: apiBaseURL,
|
remoteConfigURL: remoteConfigURL,
|
||||||
localSNIs: make(map[string]struct{}),
|
localSNIs: make(map[string]struct{}),
|
||||||
localOverrides: overridesMap,
|
localOverrides: overridesMap,
|
||||||
activeTunnels: make(map[string]*activeTunnel),
|
activeTunnels: make(map[string]*activeTunnel),
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
@@ -337,7 +337,7 @@ func (p *SNIProxy) getRoute(hostname string) (*RouteRecord, error) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Construct API URL
|
// Construct API URL
|
||||||
apiURL := fmt.Sprintf("%s/api/route/%s", p.apiBaseURL, hostname)
|
apiURL := fmt.Sprintf("%s/api/route/%s", p.remoteConfigURL, hostname)
|
||||||
|
|
||||||
// Create HTTP request
|
// Create HTTP request
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user