mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-10 20:56:41 +00:00
Big speed increase
This commit is contained in:
187
relay/relay.go
187
relay/relay.go
@@ -41,6 +41,23 @@ type InitialMappings struct {
|
|||||||
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Packet is a simple struct to hold the packet data and sender info.
|
||||||
|
type Packet struct {
|
||||||
|
data []byte
|
||||||
|
remoteAddr *net.UDPAddr
|
||||||
|
n int
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- End Types ---
|
||||||
|
|
||||||
|
// bufferPool allows reusing buffers to reduce allocations.
|
||||||
|
var bufferPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, 1500)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPProxyServer now has a channel for incoming packets.
|
||||||
type UDPProxyServer struct {
|
type UDPProxyServer struct {
|
||||||
addr string
|
addr string
|
||||||
serverURL string
|
serverURL string
|
||||||
@@ -48,24 +65,22 @@ type UDPProxyServer struct {
|
|||||||
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
|
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
|
||||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||||
publicKey wgtypes.Key
|
publicKey wgtypes.Key
|
||||||
|
packetChan chan Packet
|
||||||
}
|
}
|
||||||
|
|
||||||
type Logger interface {
|
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||||
Info(format string, args ...interface{})
|
|
||||||
Error(format string, args ...interface{})
|
|
||||||
Fatal(format string, args ...interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer {
|
func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer {
|
||||||
return &UDPProxyServer{
|
return &UDPProxyServer{
|
||||||
addr: addr,
|
addr: addr,
|
||||||
serverURL: serverURL,
|
serverURL: serverURL,
|
||||||
publicKey: publicKey,
|
publicKey: publicKey,
|
||||||
|
packetChan: make(chan Packet, 1000),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start sets up the UDP listener, worker pool, and begins reading packets.
|
||||||
func (s *UDPProxyServer) Start() error {
|
func (s *UDPProxyServer) Start() error {
|
||||||
// First fetch initial mappings
|
// Fetch initial mappings.
|
||||||
if err := s.fetchInitialMappings(); err != nil {
|
if err := s.fetchInitialMappings(); err != nil {
|
||||||
return fmt.Errorf("failed to fetch initial mappings: %v", err)
|
return fmt.Errorf("failed to fetch initial mappings: %v", err)
|
||||||
}
|
}
|
||||||
@@ -74,17 +89,25 @@ func (s *UDPProxyServer) Start() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", udpAddr)
|
conn, err := net.ListenUDP("udp", udpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.conn = conn
|
s.conn = conn
|
||||||
logger.Info("UDP server listening on %s", s.addr)
|
logger.Info("UDP server listening on %s", s.addr)
|
||||||
|
|
||||||
go s.handlePackets()
|
// Start a fixed number of worker goroutines.
|
||||||
|
workerCount := 10
|
||||||
|
for i := 0; i < workerCount; i++ {
|
||||||
|
go s.packetWorker()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the goroutine that reads packets from the UDP socket.
|
||||||
|
go s.readPackets()
|
||||||
|
|
||||||
|
// Start the idle connection cleanup routine.
|
||||||
go s.cleanupIdleConnections()
|
go s.cleanupIdleConnections()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,73 +115,103 @@ func (s *UDPProxyServer) Stop() {
|
|||||||
s.conn.Close()
|
s.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
||||||
|
func (s *UDPProxyServer) readPackets() {
|
||||||
|
for {
|
||||||
|
buf := bufferPool.Get().([]byte)
|
||||||
|
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error reading UDP packet: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// packetWorker processes incoming packets from the channel.
|
||||||
|
func (s *UDPProxyServer) packetWorker() {
|
||||||
|
for packet := range s.packetChan {
|
||||||
|
// Determine packet type by inspecting the first byte.
|
||||||
|
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
||||||
|
// Process as a WireGuard packet.
|
||||||
|
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||||
|
} else {
|
||||||
|
// Process as a hole punch message.
|
||||||
|
var msg HolePunchMessage
|
||||||
|
if err := json.Unmarshal(packet.data, &msg); err != nil {
|
||||||
|
logger.Error("Error unmarshaling message: %v", err)
|
||||||
|
} else {
|
||||||
|
endpoint := ClientEndpoint{
|
||||||
|
OlmID: msg.OlmID,
|
||||||
|
NewtID: msg.NewtID,
|
||||||
|
IP: packet.remoteAddr.IP.String(),
|
||||||
|
Port: packet.remoteAddr.Port,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
// You can call notifyServer synchronously here or dispatch further if needed.
|
||||||
|
s.notifyServer(endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Return the buffer to the pool for reuse.
|
||||||
|
bufferPool.Put(packet.data[:1500])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- The remaining methods remain largely the same ---
|
||||||
|
// For example: fetchInitialMappings, handleWireGuardPacket, getOrCreateConnection, etc.
|
||||||
|
|
||||||
func (s *UDPProxyServer) fetchInitialMappings() error {
|
func (s *UDPProxyServer) fetchInitialMappings() error {
|
||||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String())))
|
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String())))
|
||||||
|
|
||||||
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
|
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch mappings: %v", err)
|
return fmt.Errorf("failed to fetch mappings: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
return fmt.Errorf("server returned non-OK status: %d, body: %s",
|
return fmt.Errorf("server returned non-OK status: %d, body: %s",
|
||||||
resp.StatusCode, string(body))
|
resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read response body: %v", err)
|
return fmt.Errorf("failed to read response body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Received initial mappings: %s", string(data))
|
logger.Info("Received initial mappings: %s", string(data))
|
||||||
|
|
||||||
var initialMappings InitialMappings
|
var initialMappings InitialMappings
|
||||||
if err := json.Unmarshal(data, &initialMappings); err != nil {
|
if err := json.Unmarshal(data, &initialMappings); err != nil {
|
||||||
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
|
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
|
||||||
}
|
}
|
||||||
|
// Store mappings in our sync.Map.
|
||||||
// Store all mappings in our sync.Map
|
|
||||||
for key, mapping := range initialMappings.Mappings {
|
for key, mapping := range initialMappings.Mappings {
|
||||||
s.proxyMappings.Store(key, mapping)
|
s.proxyMappings.Store(key, mapping)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
|
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPProxyServer) handlePackets() {
|
// Example handleWireGuardPacket remains unchanged.
|
||||||
buffer := make([]byte, 1500) // Standard MTU size
|
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
||||||
for {
|
key := remoteAddr.String()
|
||||||
n, remoteAddr, err := s.conn.ReadFromUDP(buffer)
|
mapping, ok := s.proxyMappings.Load(key)
|
||||||
if err != nil {
|
if !ok {
|
||||||
logger.Error("Error reading UDP packet: %v", err)
|
logger.Error("No proxy mapping found for %s", key)
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
proxyMapping := mapping.(ProxyMapping)
|
||||||
// Otherwise, treat it as an incoming WireGuard or Hole Punch request
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d",
|
||||||
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
|
proxyMapping.DestinationIP, proxyMapping.DestinationPort))
|
||||||
go s.handleWireGuardPacket(buffer[:n], remoteAddr)
|
if err != nil {
|
||||||
continue
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
// Try to handle as hole punch message
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||||
var msg HolePunchMessage
|
if err != nil {
|
||||||
if err := json.Unmarshal(buffer[:n], &msg); err != nil {
|
logger.Error("Failed to get/create connection: %v", err)
|
||||||
logger.Error("Error unmarshaling message: %v", err)
|
return
|
||||||
continue
|
}
|
||||||
}
|
_, err = conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
endpoint := ClientEndpoint{
|
logger.Error("Failed to proxy packet: %v", err)
|
||||||
OlmID: msg.OlmID,
|
|
||||||
NewtID: msg.NewtID,
|
|
||||||
IP: remoteAddr.IP.String(),
|
|
||||||
Port: remoteAddr.Port,
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.notifyServer(endpoint)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,36 +260,6 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
|
||||||
key := remoteAddr.String()
|
|
||||||
mapping, ok := s.proxyMappings.Load(key)
|
|
||||||
if !ok {
|
|
||||||
logger.Error("No proxy mapping found for %s", key)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyMapping := mapping.(ProxyMapping)
|
|
||||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d",
|
|
||||||
proxyMapping.DestinationIP, proxyMapping.DestinationPort))
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve destination address: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get or create a connection to the destination
|
|
||||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to get/create connection: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward the packet
|
|
||||||
_, err = conn.Write(packet)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to proxy packet: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add a cleanup method to periodically remove idle connections
|
// Add a cleanup method to periodically remove idle connections
|
||||||
func (s *UDPProxyServer) cleanupIdleConnections() {
|
func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
|||||||
Reference in New Issue
Block a user