diff --git a/relay/relay.go b/relay/relay.go index 8c02905..1635e52 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -41,6 +41,23 @@ type InitialMappings struct { Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port" } +// Packet is a simple struct to hold the packet data and sender info. +type Packet struct { + data []byte + remoteAddr *net.UDPAddr + n int +} + +// --- End Types --- + +// bufferPool allows reusing buffers to reduce allocations. +var bufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 1500) + }, +} + +// UDPProxyServer now has a channel for incoming packets. type UDPProxyServer struct { addr string serverURL string @@ -48,24 +65,22 @@ type UDPProxyServer struct { proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port" connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" publicKey wgtypes.Key + packetChan chan Packet } -type Logger interface { - Info(format string, args ...interface{}) - Error(format string, args ...interface{}) - Fatal(format string, args ...interface{}) -} - +// NewUDPProxyServer initializes the server with a buffered packet channel. func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer { return &UDPProxyServer{ - addr: addr, - serverURL: serverURL, - publicKey: publicKey, + addr: addr, + serverURL: serverURL, + publicKey: publicKey, + packetChan: make(chan Packet, 1000), } } +// Start sets up the UDP listener, worker pool, and begins reading packets. func (s *UDPProxyServer) Start() error { - // First fetch initial mappings + // Fetch initial mappings. if err := s.fetchInitialMappings(); err != nil { return fmt.Errorf("failed to fetch initial mappings: %v", err) } @@ -74,17 +89,25 @@ func (s *UDPProxyServer) Start() error { if err != nil { return err } - conn, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } - s.conn = conn logger.Info("UDP server listening on %s", s.addr) - go s.handlePackets() + // Start a fixed number of worker goroutines. + workerCount := 10 + for i := 0; i < workerCount; i++ { + go s.packetWorker() + } + + // Start the goroutine that reads packets from the UDP socket. + go s.readPackets() + + // Start the idle connection cleanup routine. go s.cleanupIdleConnections() + return nil } @@ -92,73 +115,103 @@ func (s *UDPProxyServer) Stop() { s.conn.Close() } +// readPackets continuously reads from the UDP socket and pushes packets into the channel. +func (s *UDPProxyServer) readPackets() { + for { + buf := bufferPool.Get().([]byte) + n, remoteAddr, err := s.conn.ReadFromUDP(buf) + if err != nil { + logger.Error("Error reading UDP packet: %v", err) + continue + } + s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n} + } +} + +// packetWorker processes incoming packets from the channel. +func (s *UDPProxyServer) packetWorker() { + for packet := range s.packetChan { + // Determine packet type by inspecting the first byte. + if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 { + // Process as a WireGuard packet. + s.handleWireGuardPacket(packet.data, packet.remoteAddr) + } else { + // Process as a hole punch message. + var msg HolePunchMessage + if err := json.Unmarshal(packet.data, &msg); err != nil { + logger.Error("Error unmarshaling message: %v", err) + } else { + endpoint := ClientEndpoint{ + OlmID: msg.OlmID, + NewtID: msg.NewtID, + IP: packet.remoteAddr.IP.String(), + Port: packet.remoteAddr.Port, + Timestamp: time.Now().Unix(), + } + // You can call notifyServer synchronously here or dispatch further if needed. + s.notifyServer(endpoint) + } + } + // Return the buffer to the pool for reuse. + bufferPool.Put(packet.data[:1500]) + } +} + +// --- The remaining methods remain largely the same --- +// For example: fetchInitialMappings, handleWireGuardPacket, getOrCreateConnection, etc. + func (s *UDPProxyServer) fetchInitialMappings() error { body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String()))) - resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body) if err != nil { return fmt.Errorf("failed to fetch mappings: %v", err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return fmt.Errorf("server returned non-OK status: %d, body: %s", resp.StatusCode, string(body)) } - data, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response body: %v", err) } - logger.Info("Received initial mappings: %s", string(data)) - var initialMappings InitialMappings if err := json.Unmarshal(data, &initialMappings); err != nil { return fmt.Errorf("failed to unmarshal initial mappings: %v", err) } - - // Store all mappings in our sync.Map + // Store mappings in our sync.Map. for key, mapping := range initialMappings.Mappings { s.proxyMappings.Store(key, mapping) } - logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) return nil } -func (s *UDPProxyServer) handlePackets() { - buffer := make([]byte, 1500) // Standard MTU size - for { - n, remoteAddr, err := s.conn.ReadFromUDP(buffer) - if err != nil { - logger.Error("Error reading UDP packet: %v", err) - continue - } - - // Otherwise, treat it as an incoming WireGuard or Hole Punch request - if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 { - go s.handleWireGuardPacket(buffer[:n], remoteAddr) - continue - } - - // Try to handle as hole punch message - var msg HolePunchMessage - if err := json.Unmarshal(buffer[:n], &msg); err != nil { - logger.Error("Error unmarshaling message: %v", err) - continue - } - - endpoint := ClientEndpoint{ - OlmID: msg.OlmID, - NewtID: msg.NewtID, - IP: remoteAddr.IP.String(), - Port: remoteAddr.Port, - Timestamp: time.Now().Unix(), - } - - go s.notifyServer(endpoint) +// Example handleWireGuardPacket remains unchanged. +func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) { + key := remoteAddr.String() + mapping, ok := s.proxyMappings.Load(key) + if !ok { + logger.Error("No proxy mapping found for %s", key) + return + } + proxyMapping := mapping.(ProxyMapping) + destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", + proxyMapping.DestinationIP, proxyMapping.DestinationPort)) + if err != nil { + logger.Error("Failed to resolve destination address: %v", err) + return + } + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + return + } + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to proxy packet: %v", err) } } @@ -207,36 +260,6 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd } } -func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) { - key := remoteAddr.String() - mapping, ok := s.proxyMappings.Load(key) - if !ok { - logger.Error("No proxy mapping found for %s", key) - return - } - - proxyMapping := mapping.(ProxyMapping) - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", - proxyMapping.DestinationIP, proxyMapping.DestinationPort)) - if err != nil { - logger.Error("Failed to resolve destination address: %v", err) - return - } - - // Get or create a connection to the destination - conn, err := s.getOrCreateConnection(destAddr, remoteAddr) - if err != nil { - logger.Error("Failed to get/create connection: %v", err) - return - } - - // Forward the packet - _, err = conn.Write(packet) - if err != nil { - logger.Error("Failed to proxy packet: %v", err) - } -} - // Add a cleanup method to periodically remove idle connections func (s *UDPProxyServer) cleanupIdleConnections() { ticker := time.NewTicker(5 * time.Minute)