From 6b0ca9cab583bd29073e451dfd0636220bded7d8 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 3 Apr 2025 21:59:16 -0400 Subject: [PATCH] Adjust wgtester to work with bpf --- main.go | 17 +++ network/network.go | 96 ++++++++++++ wgtester/wgtester.go | 351 ++++++++++++------------------------------- 3 files changed, 213 insertions(+), 251 deletions(-) diff --git a/main.go b/main.go index faa80e6..6070b5f 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" "github.com/fosrl/newt/wg" + "github.com/fosrl/newt/wgtester" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" @@ -442,6 +443,7 @@ func main() { var pm *proxy.ProxyManager var connected bool var wgData WgData + var wgTesterServer *wgtester.Server if generateAndSaveKeyTo != "" { // make sure we are running on linux @@ -465,6 +467,17 @@ func main() { logger.Fatal("Failed to create WireGuard service: %v", err) } defer wgService.Close() + + wgTesterServer = wgtester.NewServer("0.0.0.0", wgService.Port, id) // TODO: maybe make this the same ip of the wg server? + err := wgTesterServer.Start() + if err != nil { + logger.Error("Failed to start WireGuard tester server: %v", err) + } else { + logger.Info("WireGuard connection testing server started on port %d", wgService.Port) + + // Make sure to stop the server on exit + defer wgTesterServer.Stop() + } } client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { @@ -711,6 +724,10 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( wgService.Close() } + if wgTesterServer != nil { + wgTesterServer.Stop() + } + if pm != nil { pm.Stop() } diff --git a/network/network.go b/network/network.go index 0703e8b..49192ca 100644 --- a/network/network.go +++ b/network/network.go @@ -200,3 +200,99 @@ func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) dstPort = binary.BigEndian.Uint16(response[22:24]) return } + +// SetupRawConnWithCustomBPF creates an ipv4 and udp RawConn with a custom BPF program +// This allows sharing the port between WireGuard and the WGTester +func SetupRawConnWithCustomBPF(server *Server, client *PeerNet, captureMagicHeader uint32) *ipv4.RawConn { + packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) + if err != nil { + log.Fatalln("Error creating packetConn:", err) + } + + rawConn, err := ipv4.NewRawConn(packetConn) + if err != nil { + log.Fatalln("Error creating rawConn:", err) + } + + // Apply a BPF that allows capturing both WireGuard and tester packets + ApplyCustomBPF(rawConn, server, client, captureMagicHeader) + + return rawConn +} + +// ApplyCustomBPF constructs a simpler BPF program that should be more compatible +// The previous filter might have been too complex for the kernel to accept +func ApplyCustomBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet, captureMagicHeader uint32) { + const ipv4HeaderLen = 20 + const udpHeaderLen = 8 + // Magic header would be located after IP + UDP headers + const magicHeaderOffset = ipv4HeaderLen + udpHeaderLen + + // Many BPF implementations have limitations on jump offsets and program complexity + // Let's create a simpler program that just looks for: + // 1. UDP Protocol + // 2. Destination port matching our listening port or source port matching our port + // 3. We'll handle the magic header check in our application code instead + + // This creates a more basic filter that will be accepted by most kernels + bpfRaw, err := bpf.Assemble([]bpf.Instruction{ + // Load IP Protocol field (at offset 9) + bpf.LoadAbsolute{Off: 9, Size: 1}, + + // Is it UDP? (17 is UDP protocol number) + bpf.JumpIf{Cond: bpf.JumpEqual, Val: 17, SkipFalse: 5, SkipTrue: 0}, + + // Load destination port (at IP header + 2) + bpf.LoadAbsolute{Off: ipv4HeaderLen + 2, Size: 2}, + + // Is it our port? + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 2, SkipTrue: 0}, + + // Accept packet + bpf.RetConstant{Val: 1<<(8*4) - 1}, + + // Not matching destination port, check source port + bpf.LoadAbsolute{Off: ipv4HeaderLen + 0, Size: 2}, + + // Is source port our port? + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, + + // Accept packet + bpf.RetConstant{Val: 1<<(8*4) - 1}, + + // Reject packet + bpf.RetConstant{Val: 0}, + }) + + if err != nil { + log.Fatalln("Error assembling BPF:", err) + } + + err = rawConn.SetBPF(bpfRaw) + if err != nil { + log.Fatalln("Error setting BPF:", err) + } +} + +// These helper functions will make it easier to extract information from packets +// ExtractUDPPayload extracts the UDP payload from a raw IP packet +func ExtractUDPPayload(packet []byte) []byte { + if len(packet) < 28 { // IP header (20) + UDP header (8) + return nil + } + return packet[28:] +} + +// ExtractIPAndPorts extracts source/dest IP and ports from a raw IP packet +func ExtractIPAndPorts(packet []byte) (srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) { + if len(packet) < 28 { + return nil, 0, nil, 0 + } + + srcIP = net.IP(packet[12:16]) + dstIP = net.IP(packet[16:20]) + srcPort = binary.BigEndian.Uint16(packet[20:22]) + dstPort = binary.BigEndian.Uint16(packet[22:24]) + + return +} diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go index cefe785..48119e8 100644 --- a/wgtester/wgtester.go +++ b/wgtester/wgtester.go @@ -1,12 +1,14 @@ package wgtester import ( - "context" "encoding/binary" - "log" "net" "sync" "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "golang.org/x/net/ipv4" ) const ( @@ -23,24 +25,30 @@ const ( packetSize = 13 ) -// Server handles listening for connection check requests +// Server handles listening for connection check requests using raw sockets type Server struct { - conn *net.UDPConn - listenAddr string - shutdownCh chan struct{} - isRunning bool - runningLock sync.Mutex + rawConn *ipv4.RawConn + serverAddr string + serverPort uint16 + shutdownCh chan struct{} + isRunning bool + runningLock sync.Mutex + newtID string + outputPrefix string } -// NewServer creates a new connection test server -func NewServer(listenAddr string) *Server { +// NewServer creates a new connection test server using raw sockets +func NewServer(serverAddr string, serverPort uint16, newtID string) *Server { return &Server{ - listenAddr: listenAddr, - shutdownCh: make(chan struct{}), + serverAddr: serverAddr, + serverPort: serverPort, + shutdownCh: make(chan struct{}), + newtID: newtID, + outputPrefix: "[WGTester] ", } } -// Start begins listening for connection test packets +// Start begins listening for connection test packets using raw sockets func (s *Server) Start() error { s.runningLock.Lock() defer s.runningLock.Unlock() @@ -49,20 +57,30 @@ func (s *Server) Start() error { return nil } - addr, err := net.ResolveUDPAddr("udp", s.listenAddr) - if err != nil { - return err + // Configure server and client for BPF filtering + server := &network.Server{ + Hostname: s.serverAddr, + Addr: network.HostToAddr(s.serverAddr), + Port: s.serverPort, } - s.conn, err = net.ListenUDP("udp", addr) - if err != nil { - return err + clientIP := network.GetClientIP(server.Addr.IP) + + // Use the server port as our client port to match the WireGuard configuration + client := &network.PeerNet{ + IP: clientIP, + Port: s.serverPort, // Use same port as server to share with WireGuard + NewtID: s.newtID, } + // Setup raw connection with custom BPF to filter for our magic header + rawConn := network.SetupRawConnWithCustomBPF(server, client, magicHeader) + s.rawConn = rawConn + s.isRunning = true go s.handleConnections() - log.Printf("Server listening on %s", s.listenAddr) + logger.Info(""+s.outputPrefix+"Server started on %s:%d", s.serverAddr, s.serverPort) return nil } @@ -76,272 +94,103 @@ func (s *Server) Stop() { } close(s.shutdownCh) - if s.conn != nil { - s.conn.Close() + if s.rawConn != nil { + s.rawConn.Close() } s.isRunning = false - log.Println("Server stopped") + logger.Info(s.outputPrefix + "Server stopped") } // handleConnections processes incoming packets func (s *Server) handleConnections() { - buffer := make([]byte, packetSize) - for { select { case <-s.shutdownCh: return default: - // Set read deadline to avoid blocking forever - s.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + // Read packet with timeout using RawConn + err := s.rawConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + logger.Error(s.outputPrefix+"Error setting read deadline: %v", err) + continue + } - n, addr, err := s.conn.ReadFromUDP(buffer) + // Create buffer for the entire IP packet + payload := make([]byte, 2000) // Large enough for any UDP packet + + // Read the packet + _, _, _, err = s.rawConn.ReadFrom(payload) if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { // Just a timeout, keep going continue } - log.Printf("Error reading from UDP: %v", err) + logger.Error(s.outputPrefix+"Error reading from UDP: %v", err) continue } - if n != packetSize { - continue // Ignore malformed packets + // Extract IP and port information + srcIP, srcPort, _, _ := network.ExtractIPAndPorts(payload) + if srcIP == nil { + continue // Invalid packet + } + + // Extract UDP payload + udpPayload := network.ExtractUDPPayload(payload) + if udpPayload == nil || len(udpPayload) < packetSize { + continue // Too small to be our packet } // Check magic header - magic := binary.BigEndian.Uint32(buffer[0:4]) + magic := binary.BigEndian.Uint32(udpPayload[0:4]) if magic != magicHeader { continue // Not our packet } // Check packet type - packetType := buffer[4] + packetType := udpPayload[4] if packetType != packetTypeRequest { continue // Not a request packet } - // Keep the timestamp the same (for RTT calculation) - // Just change the packet type to response - buffer[4] = packetTypeResponse + // Create response packet + responsePacket := make([]byte, packetSize) + // Copy the same magic header + binary.BigEndian.PutUint32(responsePacket[0:4], magicHeader) + // Change the packet type to response + responsePacket[4] = packetTypeResponse + // Copy the timestamp (for RTT calculation) + if len(udpPayload) >= 13 { + copy(responsePacket[5:13], udpPayload[5:13]) + } - // Send response - _, err = s.conn.WriteToUDP(buffer, addr) + // Use the client's source information to send the response + peerClient := &network.PeerNet{ + IP: s.rawConn.LocalAddr().(*net.IPAddr).IP, + Port: s.serverPort, + NewtID: s.newtID, + } + + // Setup target server from the source of the incoming packet + server := &network.Server{ + Hostname: srcIP.String(), + Addr: &net.IPAddr{IP: srcIP}, + Port: srcPort, + } + + // Log response being sent for debugging + logger.Debug(s.outputPrefix+"Sending response to %s:%d", srcIP.String(), srcPort) + + // Send the response packet + err = network.SendPacket(responsePacket, s.rawConn, server, peerClient) if err != nil { - log.Printf("Error sending response: %v", err) + logger.Error(s.outputPrefix+"Error sending response: %v", err) + } else { + logger.Debug(s.outputPrefix + "Response sent successfully") + } + if err != nil { + logger.Error(s.outputPrefix+"Error sending response: %v", err) } } } } - -// Client handles checking connectivity to a server -type Client struct { - conn *net.UDPConn - serverAddr string - monitorRunning bool - monitorLock sync.Mutex - shutdownCh chan struct{} - packetInterval time.Duration - timeout time.Duration - maxAttempts int -} - -// ConnectionStatus represents the current connection state -type ConnectionStatus struct { - Connected bool - RTT time.Duration -} - -// NewClient creates a new connection test client -func NewClient(serverAddr string) (*Client, error) { - return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - }, nil -} - -// SetPacketInterval changes how frequently packets are sent in monitor mode -func (c *Client) SetPacketInterval(interval time.Duration) { - c.packetInterval = interval -} - -// SetTimeout changes the timeout for waiting for responses -func (c *Client) SetTimeout(timeout time.Duration) { - c.timeout = timeout -} - -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (c *Client) SetMaxAttempts(attempts int) { - c.maxAttempts = attempts -} - -// Close cleans up client resources -func (c *Client) Close() { - c.StopMonitor() - if c.conn != nil { - c.conn.Close() - c.conn = nil - } -} - -// ensureConnection makes sure we have an active UDP connection -func (c *Client) ensureConnection() error { - if c.conn != nil { - return nil - } - - serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) - if err != nil { - return err - } - - c.conn, err = net.DialUDP("udp", nil, serverAddr) - if err != nil { - return err - } - - return nil -} - -// TestConnection checks if the connection to the server is working -// Returns true if connected, false otherwise -func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { - if err := c.ensureConnection(); err != nil { - return false, 0 - } - - // Prepare packet buffer - packet := make([]byte, packetSize) - binary.BigEndian.PutUint32(packet[0:4], magicHeader) - packet[4] = packetTypeRequest - - // Send multiple attempts as specified - for attempt := 0; attempt < c.maxAttempts; attempt++ { - select { - case <-ctx.Done(): - return false, 0 - default: - // Add current timestamp to packet - timestamp := time.Now().UnixNano() - binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp)) - - // Send the packet - _, err := c.conn.Write(packet) - if err != nil { - log.Printf("Error sending packet: %v", err) - continue - } - - // Set read deadline - c.conn.SetReadDeadline(time.Now().Add(c.timeout)) - - // Wait for response - responseBuffer := make([]byte, packetSize) - n, err := c.conn.Read(responseBuffer) - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - // Timeout, try next attempt - time.Sleep(100 * time.Millisecond) // Brief pause between attempts - continue - } - log.Printf("Error reading response: %v", err) - continue - } - - if n != packetSize { - continue // Malformed packet - } - - // Verify response - magic := binary.BigEndian.Uint32(responseBuffer[0:4]) - packetType := responseBuffer[4] - if magic != magicHeader || packetType != packetTypeResponse { - continue // Not our response - } - - // Extract the original timestamp and calculate RTT - sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13])) - rtt := time.Duration(time.Now().UnixNano() - sentTimestamp) - - return true, rtt - } - } - - return false, 0 -} - -// TestConnectionWithTimeout tries to test connection with a timeout -// Returns true if connected, false otherwise -func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return c.TestConnection(ctx) -} - -// MonitorCallback is the function type for connection status change callbacks -type MonitorCallback func(status ConnectionStatus) - -// StartMonitor begins monitoring the connection and calls the callback -// when the connection status changes -func (c *Client) StartMonitor(callback MonitorCallback) error { - c.monitorLock.Lock() - defer c.monitorLock.Unlock() - - if c.monitorRunning { - return nil // Already running - } - - if err := c.ensureConnection(); err != nil { - return err - } - - c.monitorRunning = true - c.shutdownCh = make(chan struct{}) - - go func() { - var lastConnected bool - firstRun := true - - ticker := time.NewTicker(c.packetInterval) - defer ticker.Stop() - - for { - select { - case <-c.shutdownCh: - return - case <-ticker.C: - ctx, cancel := context.WithTimeout(context.Background(), c.timeout) - connected, rtt := c.TestConnection(ctx) - cancel() - - // Callback if status changed or it's the first check - if connected != lastConnected || firstRun { - callback(ConnectionStatus{ - Connected: connected, - RTT: rtt, - }) - lastConnected = connected - firstRun = false - } - } - } - }() - - return nil -} - -// StopMonitor stops the connection monitoring -func (c *Client) StopMonitor() { - c.monitorLock.Lock() - defer c.monitorLock.Unlock() - - if !c.monitorRunning { - return - } - - close(c.shutdownCh) - c.monitorRunning = false -}