From 0ced66e157cc2b2a0b7ac4f444eb7afe21a76d58 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 11 Apr 2025 20:52:29 -0400 Subject: [PATCH] Relaying working --- network/network.go | 103 ------------------------------------------- wg/wg.go | 79 +++++++++++++++++++-------------- wgtester/wgtester.go | 102 +++++++++++++++--------------------------- 3 files changed, 82 insertions(+), 202 deletions(-) diff --git a/network/network.go b/network/network.go index 49192ca..e359219 100644 --- a/network/network.go +++ b/network/network.go @@ -193,106 +193,3 @@ func ParseResponse(response []byte) (net.IP, uint16) { port := binary.BigEndian.Uint16(response[4:6]) return ip, port } - -func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) { - srcIP = net.IP(response[12:16]) - srcPort = binary.BigEndian.Uint16(response[20:22]) - dstPort = binary.BigEndian.Uint16(response[22:24]) - return -} - -// SetupRawConnWithCustomBPF creates an ipv4 and udp RawConn with a custom BPF program -// This allows sharing the port between WireGuard and the WGTester -func SetupRawConnWithCustomBPF(server *Server, client *PeerNet, captureMagicHeader uint32) *ipv4.RawConn { - packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) - if err != nil { - log.Fatalln("Error creating packetConn:", err) - } - - rawConn, err := ipv4.NewRawConn(packetConn) - if err != nil { - log.Fatalln("Error creating rawConn:", err) - } - - // Apply a BPF that allows capturing both WireGuard and tester packets - ApplyCustomBPF(rawConn, server, client, captureMagicHeader) - - return rawConn -} - -// ApplyCustomBPF constructs a simpler BPF program that should be more compatible -// The previous filter might have been too complex for the kernel to accept -func ApplyCustomBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet, captureMagicHeader uint32) { - const ipv4HeaderLen = 20 - const udpHeaderLen = 8 - // Magic header would be located after IP + UDP headers - const magicHeaderOffset = ipv4HeaderLen + udpHeaderLen - - // Many BPF implementations have limitations on jump offsets and program complexity - // Let's create a simpler program that just looks for: - // 1. UDP Protocol - // 2. Destination port matching our listening port or source port matching our port - // 3. We'll handle the magic header check in our application code instead - - // This creates a more basic filter that will be accepted by most kernels - bpfRaw, err := bpf.Assemble([]bpf.Instruction{ - // Load IP Protocol field (at offset 9) - bpf.LoadAbsolute{Off: 9, Size: 1}, - - // Is it UDP? (17 is UDP protocol number) - bpf.JumpIf{Cond: bpf.JumpEqual, Val: 17, SkipFalse: 5, SkipTrue: 0}, - - // Load destination port (at IP header + 2) - bpf.LoadAbsolute{Off: ipv4HeaderLen + 2, Size: 2}, - - // Is it our port? - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 2, SkipTrue: 0}, - - // Accept packet - bpf.RetConstant{Val: 1<<(8*4) - 1}, - - // Not matching destination port, check source port - bpf.LoadAbsolute{Off: ipv4HeaderLen + 0, Size: 2}, - - // Is source port our port? - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, - - // Accept packet - bpf.RetConstant{Val: 1<<(8*4) - 1}, - - // Reject packet - bpf.RetConstant{Val: 0}, - }) - - if err != nil { - log.Fatalln("Error assembling BPF:", err) - } - - err = rawConn.SetBPF(bpfRaw) - if err != nil { - log.Fatalln("Error setting BPF:", err) - } -} - -// These helper functions will make it easier to extract information from packets -// ExtractUDPPayload extracts the UDP payload from a raw IP packet -func ExtractUDPPayload(packet []byte) []byte { - if len(packet) < 28 { // IP header (20) + UDP header (8) - return nil - } - return packet[28:] -} - -// ExtractIPAndPorts extracts source/dest IP and ports from a raw IP packet -func ExtractIPAndPorts(packet []byte) (srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) { - if len(packet) < 28 { - return nil, 0, nil, 0 - } - - srcIP = net.IP(packet[12:16]) - dstIP = net.IP(packet[16:20]) - srcPort = binary.BigEndian.Uint16(packet[20:22]) - dstPort = binary.BigEndian.Uint16(packet[22:24]) - - return -} diff --git a/wg/wg.go b/wg/wg.go index b879c9c..8095606 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -80,13 +80,20 @@ func NewFixedPortBind(port uint16) conn.Bind { } } +// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { if maxPort < minPort { return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) } - // Create a slice of all ports in the range - portRange := make([]uint16, maxPort-minPort+1) + // We need to check port+1 as well, so adjust the max port to avoid going out of range + adjustedMaxPort := maxPort - 1 + if adjustedMaxPort < minPort { + return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort) + } + + // Create a slice of all ports in the range (excluding the last one) + portRange := make([]uint16, adjustedMaxPort-minPort+1) for i := range portRange { portRange[i] = minPort + uint16(i) } @@ -100,20 +107,35 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { // Try each port in the randomized order for _, port := range portRange { - addr := &net.UDPAddr{ + // Check if port is available + addr1 := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: int(port), } - conn, err := net.ListenUDP("udp", addr) - if err != nil { + conn1, err1 := net.ListenUDP("udp", addr1) + if err1 != nil { continue // Port is in use or there was an error, try next port } - _ = conn.SetDeadline(time.Now()) - conn.Close() + + // Check if port+1 is also available + addr2 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port + 1), + } + conn2, err2 := net.ListenUDP("udp", addr2) + if err2 != nil { + // The next port is not available, so close the first connection and try again + conn1.Close() + continue + } + + // Both ports are available, close connections and return the first port + conn1.Close() + conn2.Close() return port, nil } - return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) + return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort) } func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { @@ -408,6 +430,7 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { } func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) var peer Peer jsonData, err := json.Marshal(msg.Data) @@ -451,8 +474,6 @@ func (s *WireGuardService) addPeer(peer Peer) error { return fmt.Errorf("failed to resolve endpoint address: %w", err) } - // make the endpoint localhost to test - peerConfig = wgtypes.PeerConfig{ PublicKey: pubKey, AllowedIPs: allowedIPs, @@ -482,6 +503,7 @@ func (s *WireGuardService) addPeer(peer Peer) error { } func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } type RemoveRequest struct { PublicKey string `json:"publicKey"` @@ -529,38 +551,34 @@ func (s *WireGuardService) removePeer(publicKey string) error { } func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) // Define a struct to match the incoming message structure with optional fields type UpdatePeerRequest struct { PublicKey string `json:"publicKey"` AllowedIPs []string `json:"allowedIps,omitempty"` Endpoint string `json:"endpoint,omitempty"` } - jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) return } - var request UpdatePeerRequest if err := json.Unmarshal(jsonData, &request); err != nil { logger.Info("Error unmarshaling peer data: %v", err) return } - // First, get the current peer configuration to preserve any unmodified fields device, err := s.wgClient.Device(s.interfaceName) if err != nil { logger.Info("Error getting WireGuard device: %v", err) return } - pubKey, err := wgtypes.ParseKey(request.PublicKey) if err != nil { logger.Info("Error parsing public key: %v", err) return } - // Find the existing peer configuration var currentPeer *wgtypes.Peer for _, p := range device.Peers { @@ -569,22 +587,30 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { break } } - if currentPeer == nil { logger.Info("Peer %s not found, cannot update", request.PublicKey) return } - // Create the update peer config peerConfig := wgtypes.PeerConfig{ PublicKey: pubKey, UpdateOnly: true, } - // Keep the default persistent keepalive of 1 second keepalive := time.Second peerConfig.PersistentKeepaliveInterval = &keepalive + // Handle Endpoint field special case + // If Endpoint is included in the request but empty, we want to remove the endpoint + // If Endpoint is not included, we don't modify it + endpointSpecified := false + for key := range msg.Data.(map[string]interface{}) { + if key == "endpoint" { + endpointSpecified = true + break + } + } + // Only update AllowedIPs if provided in the request if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 { var allowedIPs []net.IPNet @@ -597,18 +623,10 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { allowedIPs = append(allowedIPs, *ipNet) } peerConfig.AllowedIPs = allowedIPs + peerConfig.ReplaceAllowedIPs = true logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) - } - - // Handle Endpoint field special case - // If Endpoint is included in the request but empty, we want to remove the endpoint - // If Endpoint is not included, we don't modify it - endpointSpecified := false - for key := range msg.Data.(map[string]interface{}) { - if key == "endpoint" { - endpointSpecified = true - break - } + } else if endpointSpecified && request.Endpoint == "" { + peerConfig.ReplaceAllowedIPs = false } if endpointSpecified { @@ -623,7 +641,6 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) } else { // Request contained endpoint field but it was empty/null - remove endpoint - // To remove an endpoint in WireGuard, we set it to nil and specify ReplaceAllowedIPs peerConfig.Endpoint = nil logger.Info("Removing Endpoint for peer %s", request.PublicKey) } @@ -633,12 +650,10 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { config := wgtypes.Config{ Peers: []wgtypes.PeerConfig{peerConfig}, } - if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { logger.Info("Error updating peer configuration: %v", err) return } - logger.Info("Peer %s updated successfully", request.PublicKey) } diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go index 48119e8..b302fd4 100644 --- a/wgtester/wgtester.go +++ b/wgtester/wgtester.go @@ -2,13 +2,12 @@ package wgtester import ( "encoding/binary" + "fmt" "net" "sync" "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "golang.org/x/net/ipv4" ) const ( @@ -25,9 +24,9 @@ const ( packetSize = 13 ) -// Server handles listening for connection check requests using raw sockets +// Server handles listening for connection check requests using UDP type Server struct { - rawConn *ipv4.RawConn + conn *net.UDPConn serverAddr string serverPort uint16 shutdownCh chan struct{} @@ -37,18 +36,18 @@ type Server struct { outputPrefix string } -// NewServer creates a new connection test server using raw sockets +// NewServer creates a new connection test server using UDP func NewServer(serverAddr string, serverPort uint16, newtID string) *Server { return &Server{ serverAddr: serverAddr, - serverPort: serverPort, + serverPort: serverPort + 1, // use the next port for the server shutdownCh: make(chan struct{}), newtID: newtID, outputPrefix: "[WGTester] ", } } -// Start begins listening for connection test packets using raw sockets +// Start begins listening for connection test packets using UDP func (s *Server) Start() error { s.runningLock.Lock() defer s.runningLock.Unlock() @@ -57,30 +56,26 @@ func (s *Server) Start() error { return nil } - // Configure server and client for BPF filtering - server := &network.Server{ - Hostname: s.serverAddr, - Addr: network.HostToAddr(s.serverAddr), - Port: s.serverPort, + //create the address to listen on + addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort)) + + // Create UDP address to listen on + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err } - clientIP := network.GetClientIP(server.Addr.IP) - - // Use the server port as our client port to match the WireGuard configuration - client := &network.PeerNet{ - IP: clientIP, - Port: s.serverPort, // Use same port as server to share with WireGuard - NewtID: s.newtID, + // Create UDP connection + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err } - - // Setup raw connection with custom BPF to filter for our magic header - rawConn := network.SetupRawConnWithCustomBPF(server, client, magicHeader) - s.rawConn = rawConn + s.conn = conn s.isRunning = true go s.handleConnections() - logger.Info(""+s.outputPrefix+"Server started on %s:%d", s.serverAddr, s.serverPort) + logger.Info("%sServer started on %s:%d", s.outputPrefix, s.serverAddr, s.serverPort) return nil } @@ -94,8 +89,8 @@ func (s *Server) Stop() { } close(s.shutdownCh) - if s.rawConn != nil { - s.rawConn.Close() + if s.conn != nil { + s.conn.Close() } s.isRunning = false logger.Info(s.outputPrefix + "Server stopped") @@ -103,23 +98,22 @@ func (s *Server) Stop() { // handleConnections processes incoming packets func (s *Server) handleConnections() { + buffer := make([]byte, 2000) // Buffer large enough for any UDP packet + for { select { case <-s.shutdownCh: return default: - // Read packet with timeout using RawConn - err := s.rawConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + // Set read deadline to avoid blocking forever + err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) if err != nil { logger.Error(s.outputPrefix+"Error setting read deadline: %v", err) continue } - // Create buffer for the entire IP packet - payload := make([]byte, 2000) // Large enough for any UDP packet - - // Read the packet - _, _, _, err = s.rawConn.ReadFrom(payload) + // Read from UDP connection + n, addr, err := s.conn.ReadFromUDP(buffer) if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { // Just a timeout, keep going @@ -129,26 +123,19 @@ func (s *Server) handleConnections() { continue } - // Extract IP and port information - srcIP, srcPort, _, _ := network.ExtractIPAndPorts(payload) - if srcIP == nil { - continue // Invalid packet - } - - // Extract UDP payload - udpPayload := network.ExtractUDPPayload(payload) - if udpPayload == nil || len(udpPayload) < packetSize { + // Process packet only if it meets minimum size requirements + if n < packetSize { continue // Too small to be our packet } // Check magic header - magic := binary.BigEndian.Uint32(udpPayload[0:4]) + magic := binary.BigEndian.Uint32(buffer[0:4]) if magic != magicHeader { continue // Not our packet } // Check packet type - packetType := udpPayload[4] + packetType := buffer[4] if packetType != packetTypeRequest { continue // Not a request packet } @@ -160,37 +147,18 @@ func (s *Server) handleConnections() { // Change the packet type to response responsePacket[4] = packetTypeResponse // Copy the timestamp (for RTT calculation) - if len(udpPayload) >= 13 { - copy(responsePacket[5:13], udpPayload[5:13]) - } - - // Use the client's source information to send the response - peerClient := &network.PeerNet{ - IP: s.rawConn.LocalAddr().(*net.IPAddr).IP, - Port: s.serverPort, - NewtID: s.newtID, - } - - // Setup target server from the source of the incoming packet - server := &network.Server{ - Hostname: srcIP.String(), - Addr: &net.IPAddr{IP: srcIP}, - Port: srcPort, - } + copy(responsePacket[5:13], buffer[5:13]) // Log response being sent for debugging - logger.Debug(s.outputPrefix+"Sending response to %s:%d", srcIP.String(), srcPort) + logger.Debug(s.outputPrefix+"Sending response to %s", addr.String()) - // Send the response packet - err = network.SendPacket(responsePacket, s.rawConn, server, peerClient) + // Send the response packet directly to the source address + _, err = s.conn.WriteToUDP(responsePacket, addr) if err != nil { logger.Error(s.outputPrefix+"Error sending response: %v", err) } else { logger.Debug(s.outputPrefix + "Response sent successfully") } - if err != nil { - logger.Error(s.outputPrefix+"Error sending response: %v", err) - } } } }