mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
197 lines
5.0 KiB
Go
197 lines
5.0 KiB
Go
package wgtester
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/fosrl/newt/logger"
|
|
"github.com/fosrl/newt/network"
|
|
"golang.org/x/net/ipv4"
|
|
)
|
|
|
|
const (
|
|
// Magic bytes to identify our packets
|
|
magicHeader uint32 = 0xDEADBEEF
|
|
// Request packet type
|
|
packetTypeRequest uint8 = 1
|
|
// Response packet type
|
|
packetTypeResponse uint8 = 2
|
|
// Packet format:
|
|
// - 4 bytes: magic header (0xDEADBEEF)
|
|
// - 1 byte: packet type (1 = request, 2 = response)
|
|
// - 8 bytes: timestamp (for round-trip timing)
|
|
packetSize = 13
|
|
)
|
|
|
|
// Server handles listening for connection check requests using raw sockets
|
|
type Server struct {
|
|
rawConn *ipv4.RawConn
|
|
serverAddr string
|
|
serverPort uint16
|
|
shutdownCh chan struct{}
|
|
isRunning bool
|
|
runningLock sync.Mutex
|
|
newtID string
|
|
outputPrefix string
|
|
}
|
|
|
|
// NewServer creates a new connection test server using raw sockets
|
|
func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
|
|
return &Server{
|
|
serverAddr: serverAddr,
|
|
serverPort: serverPort,
|
|
shutdownCh: make(chan struct{}),
|
|
newtID: newtID,
|
|
outputPrefix: "[WGTester] ",
|
|
}
|
|
}
|
|
|
|
// Start begins listening for connection test packets using raw sockets
|
|
func (s *Server) Start() error {
|
|
s.runningLock.Lock()
|
|
defer s.runningLock.Unlock()
|
|
|
|
if s.isRunning {
|
|
return nil
|
|
}
|
|
|
|
// Configure server and client for BPF filtering
|
|
server := &network.Server{
|
|
Hostname: s.serverAddr,
|
|
Addr: network.HostToAddr(s.serverAddr),
|
|
Port: s.serverPort,
|
|
}
|
|
|
|
clientIP := network.GetClientIP(server.Addr.IP)
|
|
|
|
// Use the server port as our client port to match the WireGuard configuration
|
|
client := &network.PeerNet{
|
|
IP: clientIP,
|
|
Port: s.serverPort, // Use same port as server to share with WireGuard
|
|
NewtID: s.newtID,
|
|
}
|
|
|
|
// Setup raw connection with custom BPF to filter for our magic header
|
|
rawConn := network.SetupRawConnWithCustomBPF(server, client, magicHeader)
|
|
s.rawConn = rawConn
|
|
|
|
s.isRunning = true
|
|
go s.handleConnections()
|
|
|
|
logger.Info(""+s.outputPrefix+"Server started on %s:%d", s.serverAddr, s.serverPort)
|
|
return nil
|
|
}
|
|
|
|
// Stop shuts down the server
|
|
func (s *Server) Stop() {
|
|
s.runningLock.Lock()
|
|
defer s.runningLock.Unlock()
|
|
|
|
if !s.isRunning {
|
|
return
|
|
}
|
|
|
|
close(s.shutdownCh)
|
|
if s.rawConn != nil {
|
|
s.rawConn.Close()
|
|
}
|
|
s.isRunning = false
|
|
logger.Info(s.outputPrefix + "Server stopped")
|
|
}
|
|
|
|
// handleConnections processes incoming packets
|
|
func (s *Server) handleConnections() {
|
|
for {
|
|
select {
|
|
case <-s.shutdownCh:
|
|
return
|
|
default:
|
|
// Read packet with timeout using RawConn
|
|
err := s.rawConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
|
if err != nil {
|
|
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Create buffer for the entire IP packet
|
|
payload := make([]byte, 2000) // Large enough for any UDP packet
|
|
|
|
// Read the packet
|
|
_, _, _, err = s.rawConn.ReadFrom(payload)
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
// Just a timeout, keep going
|
|
continue
|
|
}
|
|
logger.Error(s.outputPrefix+"Error reading from UDP: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Extract IP and port information
|
|
srcIP, srcPort, _, _ := network.ExtractIPAndPorts(payload)
|
|
if srcIP == nil {
|
|
continue // Invalid packet
|
|
}
|
|
|
|
// Extract UDP payload
|
|
udpPayload := network.ExtractUDPPayload(payload)
|
|
if udpPayload == nil || len(udpPayload) < packetSize {
|
|
continue // Too small to be our packet
|
|
}
|
|
|
|
// Check magic header
|
|
magic := binary.BigEndian.Uint32(udpPayload[0:4])
|
|
if magic != magicHeader {
|
|
continue // Not our packet
|
|
}
|
|
|
|
// Check packet type
|
|
packetType := udpPayload[4]
|
|
if packetType != packetTypeRequest {
|
|
continue // Not a request packet
|
|
}
|
|
|
|
// Create response packet
|
|
responsePacket := make([]byte, packetSize)
|
|
// Copy the same magic header
|
|
binary.BigEndian.PutUint32(responsePacket[0:4], magicHeader)
|
|
// Change the packet type to response
|
|
responsePacket[4] = packetTypeResponse
|
|
// Copy the timestamp (for RTT calculation)
|
|
if len(udpPayload) >= 13 {
|
|
copy(responsePacket[5:13], udpPayload[5:13])
|
|
}
|
|
|
|
// Use the client's source information to send the response
|
|
peerClient := &network.PeerNet{
|
|
IP: s.rawConn.LocalAddr().(*net.IPAddr).IP,
|
|
Port: s.serverPort,
|
|
NewtID: s.newtID,
|
|
}
|
|
|
|
// Setup target server from the source of the incoming packet
|
|
server := &network.Server{
|
|
Hostname: srcIP.String(),
|
|
Addr: &net.IPAddr{IP: srcIP},
|
|
Port: srcPort,
|
|
}
|
|
|
|
// Log response being sent for debugging
|
|
logger.Debug(s.outputPrefix+"Sending response to %s:%d", srcIP.String(), srcPort)
|
|
|
|
// Send the response packet
|
|
err = network.SendPacket(responsePacket, s.rawConn, server, peerClient)
|
|
if err != nil {
|
|
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
|
} else {
|
|
logger.Debug(s.outputPrefix + "Response sent successfully")
|
|
}
|
|
if err != nil {
|
|
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|