Big speed increase

This commit is contained in:
Owen
2025-02-23 18:43:37 -05:00
parent f7c0bb9135
commit 093a4c21f2

View File

@@ -41,6 +41,23 @@ type InitialMappings struct {
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port" Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
} }
// Packet is a simple struct to hold the packet data and sender info.
type Packet struct {
data []byte
remoteAddr *net.UDPAddr
n int
}
// --- End Types ---
// bufferPool allows reusing buffers to reduce allocations.
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1500)
},
}
// UDPProxyServer now has a channel for incoming packets.
type UDPProxyServer struct { type UDPProxyServer struct {
addr string addr string
serverURL string serverURL string
@@ -48,24 +65,22 @@ type UDPProxyServer struct {
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port" proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
publicKey wgtypes.Key publicKey wgtypes.Key
packetChan chan Packet
} }
type Logger interface { // NewUDPProxyServer initializes the server with a buffered packet channel.
Info(format string, args ...interface{})
Error(format string, args ...interface{})
Fatal(format string, args ...interface{})
}
func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer { func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer {
return &UDPProxyServer{ return &UDPProxyServer{
addr: addr, addr: addr,
serverURL: serverURL, serverURL: serverURL,
publicKey: publicKey, publicKey: publicKey,
packetChan: make(chan Packet, 1000),
} }
} }
// Start sets up the UDP listener, worker pool, and begins reading packets.
func (s *UDPProxyServer) Start() error { func (s *UDPProxyServer) Start() error {
// First fetch initial mappings // Fetch initial mappings.
if err := s.fetchInitialMappings(); err != nil { if err := s.fetchInitialMappings(); err != nil {
return fmt.Errorf("failed to fetch initial mappings: %v", err) return fmt.Errorf("failed to fetch initial mappings: %v", err)
} }
@@ -74,17 +89,25 @@ func (s *UDPProxyServer) Start() error {
if err != nil { if err != nil {
return err return err
} }
conn, err := net.ListenUDP("udp", udpAddr) conn, err := net.ListenUDP("udp", udpAddr)
if err != nil { if err != nil {
return err return err
} }
s.conn = conn s.conn = conn
logger.Info("UDP server listening on %s", s.addr) logger.Info("UDP server listening on %s", s.addr)
go s.handlePackets() // Start a fixed number of worker goroutines.
workerCount := 10
for i := 0; i < workerCount; i++ {
go s.packetWorker()
}
// Start the goroutine that reads packets from the UDP socket.
go s.readPackets()
// Start the idle connection cleanup routine.
go s.cleanupIdleConnections() go s.cleanupIdleConnections()
return nil return nil
} }
@@ -92,73 +115,103 @@ func (s *UDPProxyServer) Stop() {
s.conn.Close() s.conn.Close()
} }
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
logger.Error("Error reading UDP packet: %v", err)
continue
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
}
// packetWorker processes incoming packets from the channel.
func (s *UDPProxyServer) packetWorker() {
for packet := range s.packetChan {
// Determine packet type by inspecting the first byte.
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
// Process as a WireGuard packet.
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
} else {
// Process as a hole punch message.
var msg HolePunchMessage
if err := json.Unmarshal(packet.data, &msg); err != nil {
logger.Error("Error unmarshaling message: %v", err)
} else {
endpoint := ClientEndpoint{
OlmID: msg.OlmID,
NewtID: msg.NewtID,
IP: packet.remoteAddr.IP.String(),
Port: packet.remoteAddr.Port,
Timestamp: time.Now().Unix(),
}
// You can call notifyServer synchronously here or dispatch further if needed.
s.notifyServer(endpoint)
}
}
// Return the buffer to the pool for reuse.
bufferPool.Put(packet.data[:1500])
}
}
// --- The remaining methods remain largely the same ---
// For example: fetchInitialMappings, handleWireGuardPacket, getOrCreateConnection, etc.
func (s *UDPProxyServer) fetchInitialMappings() error { func (s *UDPProxyServer) fetchInitialMappings() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String()))) body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String())))
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body) resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch mappings: %v", err) return fmt.Errorf("failed to fetch mappings: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("server returned non-OK status: %d, body: %s", return fmt.Errorf("server returned non-OK status: %d, body: %s",
resp.StatusCode, string(body)) resp.StatusCode, string(body))
} }
data, err := io.ReadAll(resp.Body) data, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to read response body: %v", err) return fmt.Errorf("failed to read response body: %v", err)
} }
logger.Info("Received initial mappings: %s", string(data)) logger.Info("Received initial mappings: %s", string(data))
var initialMappings InitialMappings var initialMappings InitialMappings
if err := json.Unmarshal(data, &initialMappings); err != nil { if err := json.Unmarshal(data, &initialMappings); err != nil {
return fmt.Errorf("failed to unmarshal initial mappings: %v", err) return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
} }
// Store mappings in our sync.Map.
// Store all mappings in our sync.Map
for key, mapping := range initialMappings.Mappings { for key, mapping := range initialMappings.Mappings {
s.proxyMappings.Store(key, mapping) s.proxyMappings.Store(key, mapping)
} }
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
return nil return nil
} }
func (s *UDPProxyServer) handlePackets() { // Example handleWireGuardPacket remains unchanged.
buffer := make([]byte, 1500) // Standard MTU size func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
for { key := remoteAddr.String()
n, remoteAddr, err := s.conn.ReadFromUDP(buffer) mapping, ok := s.proxyMappings.Load(key)
if err != nil { if !ok {
logger.Error("Error reading UDP packet: %v", err) logger.Error("No proxy mapping found for %s", key)
continue return
} }
proxyMapping := mapping.(ProxyMapping)
// Otherwise, treat it as an incoming WireGuard or Hole Punch request destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d",
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 { proxyMapping.DestinationIP, proxyMapping.DestinationPort))
go s.handleWireGuardPacket(buffer[:n], remoteAddr) if err != nil {
continue logger.Error("Failed to resolve destination address: %v", err)
} return
}
// Try to handle as hole punch message conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
var msg HolePunchMessage if err != nil {
if err := json.Unmarshal(buffer[:n], &msg); err != nil { logger.Error("Failed to get/create connection: %v", err)
logger.Error("Error unmarshaling message: %v", err) return
continue }
} _, err = conn.Write(packet)
if err != nil {
endpoint := ClientEndpoint{ logger.Error("Failed to proxy packet: %v", err)
OlmID: msg.OlmID,
NewtID: msg.NewtID,
IP: remoteAddr.IP.String(),
Port: remoteAddr.Port,
Timestamp: time.Now().Unix(),
}
go s.notifyServer(endpoint)
} }
} }
@@ -207,36 +260,6 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
} }
} }
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
key := remoteAddr.String()
mapping, ok := s.proxyMappings.Load(key)
if !ok {
logger.Error("No proxy mapping found for %s", key)
return
}
proxyMapping := mapping.(ProxyMapping)
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d",
proxyMapping.DestinationIP, proxyMapping.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
return
}
// Get or create a connection to the destination
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
return
}
// Forward the packet
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to proxy packet: %v", err)
}
}
// Add a cleanup method to periodically remove idle connections // Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() { func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)