Relaying working

This commit is contained in:
Owen
2025-04-11 20:52:29 -04:00
parent 6b0ca9cab5
commit 0ced66e157
3 changed files with 82 additions and 202 deletions

View File

@@ -2,13 +2,12 @@ package wgtester
import (
"encoding/binary"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"golang.org/x/net/ipv4"
)
const (
@@ -25,9 +24,9 @@ const (
packetSize = 13
)
// Server handles listening for connection check requests using raw sockets
// Server handles listening for connection check requests using UDP
type Server struct {
rawConn *ipv4.RawConn
conn *net.UDPConn
serverAddr string
serverPort uint16
shutdownCh chan struct{}
@@ -37,18 +36,18 @@ type Server struct {
outputPrefix string
}
// NewServer creates a new connection test server using raw sockets
// NewServer creates a new connection test server using UDP
func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
return &Server{
serverAddr: serverAddr,
serverPort: serverPort,
serverPort: serverPort + 1, // use the next port for the server
shutdownCh: make(chan struct{}),
newtID: newtID,
outputPrefix: "[WGTester] ",
}
}
// Start begins listening for connection test packets using raw sockets
// Start begins listening for connection test packets using UDP
func (s *Server) Start() error {
s.runningLock.Lock()
defer s.runningLock.Unlock()
@@ -57,30 +56,26 @@ func (s *Server) Start() error {
return nil
}
// Configure server and client for BPF filtering
server := &network.Server{
Hostname: s.serverAddr,
Addr: network.HostToAddr(s.serverAddr),
Port: s.serverPort,
//create the address to listen on
addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort))
// Create UDP address to listen on
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
clientIP := network.GetClientIP(server.Addr.IP)
// Use the server port as our client port to match the WireGuard configuration
client := &network.PeerNet{
IP: clientIP,
Port: s.serverPort, // Use same port as server to share with WireGuard
NewtID: s.newtID,
// Create UDP connection
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
// Setup raw connection with custom BPF to filter for our magic header
rawConn := network.SetupRawConnWithCustomBPF(server, client, magicHeader)
s.rawConn = rawConn
s.conn = conn
s.isRunning = true
go s.handleConnections()
logger.Info(""+s.outputPrefix+"Server started on %s:%d", s.serverAddr, s.serverPort)
logger.Info("%sServer started on %s:%d", s.outputPrefix, s.serverAddr, s.serverPort)
return nil
}
@@ -94,8 +89,8 @@ func (s *Server) Stop() {
}
close(s.shutdownCh)
if s.rawConn != nil {
s.rawConn.Close()
if s.conn != nil {
s.conn.Close()
}
s.isRunning = false
logger.Info(s.outputPrefix + "Server stopped")
@@ -103,23 +98,22 @@ func (s *Server) Stop() {
// handleConnections processes incoming packets
func (s *Server) handleConnections() {
buffer := make([]byte, 2000) // Buffer large enough for any UDP packet
for {
select {
case <-s.shutdownCh:
return
default:
// Read packet with timeout using RawConn
err := s.rawConn.SetReadDeadline(time.Now().Add(1 * time.Second))
// Set read deadline to avoid blocking forever
err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
if err != nil {
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
continue
}
// Create buffer for the entire IP packet
payload := make([]byte, 2000) // Large enough for any UDP packet
// Read the packet
_, _, _, err = s.rawConn.ReadFrom(payload)
// Read from UDP connection
n, addr, err := s.conn.ReadFromUDP(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Just a timeout, keep going
@@ -129,26 +123,19 @@ func (s *Server) handleConnections() {
continue
}
// Extract IP and port information
srcIP, srcPort, _, _ := network.ExtractIPAndPorts(payload)
if srcIP == nil {
continue // Invalid packet
}
// Extract UDP payload
udpPayload := network.ExtractUDPPayload(payload)
if udpPayload == nil || len(udpPayload) < packetSize {
// Process packet only if it meets minimum size requirements
if n < packetSize {
continue // Too small to be our packet
}
// Check magic header
magic := binary.BigEndian.Uint32(udpPayload[0:4])
magic := binary.BigEndian.Uint32(buffer[0:4])
if magic != magicHeader {
continue // Not our packet
}
// Check packet type
packetType := udpPayload[4]
packetType := buffer[4]
if packetType != packetTypeRequest {
continue // Not a request packet
}
@@ -160,37 +147,18 @@ func (s *Server) handleConnections() {
// Change the packet type to response
responsePacket[4] = packetTypeResponse
// Copy the timestamp (for RTT calculation)
if len(udpPayload) >= 13 {
copy(responsePacket[5:13], udpPayload[5:13])
}
// Use the client's source information to send the response
peerClient := &network.PeerNet{
IP: s.rawConn.LocalAddr().(*net.IPAddr).IP,
Port: s.serverPort,
NewtID: s.newtID,
}
// Setup target server from the source of the incoming packet
server := &network.Server{
Hostname: srcIP.String(),
Addr: &net.IPAddr{IP: srcIP},
Port: srcPort,
}
copy(responsePacket[5:13], buffer[5:13])
// Log response being sent for debugging
logger.Debug(s.outputPrefix+"Sending response to %s:%d", srcIP.String(), srcPort)
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
// Send the response packet
err = network.SendPacket(responsePacket, s.rawConn, server, peerClient)
// Send the response packet directly to the source address
_, err = s.conn.WriteToUDP(responsePacket, addr)
if err != nil {
logger.Error(s.outputPrefix+"Error sending response: %v", err)
} else {
logger.Debug(s.outputPrefix + "Response sent successfully")
}
if err != nil {
logger.Error(s.outputPrefix+"Error sending response: %v", err)
}
}
}
}