mirror of
https://github.com/fosrl/newt.git
synced 2026-03-04 17:56:40 +00:00
Relaying working
This commit is contained in:
@@ -193,106 +193,3 @@ func ParseResponse(response []byte) (net.IP, uint16) {
|
|||||||
port := binary.BigEndian.Uint16(response[4:6])
|
port := binary.BigEndian.Uint16(response[4:6])
|
||||||
return ip, port
|
return ip, port
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) {
|
|
||||||
srcIP = net.IP(response[12:16])
|
|
||||||
srcPort = binary.BigEndian.Uint16(response[20:22])
|
|
||||||
dstPort = binary.BigEndian.Uint16(response[22:24])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupRawConnWithCustomBPF creates an ipv4 and udp RawConn with a custom BPF program
|
|
||||||
// This allows sharing the port between WireGuard and the WGTester
|
|
||||||
func SetupRawConnWithCustomBPF(server *Server, client *PeerNet, captureMagicHeader uint32) *ipv4.RawConn {
|
|
||||||
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error creating packetConn:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error creating rawConn:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply a BPF that allows capturing both WireGuard and tester packets
|
|
||||||
ApplyCustomBPF(rawConn, server, client, captureMagicHeader)
|
|
||||||
|
|
||||||
return rawConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyCustomBPF constructs a simpler BPF program that should be more compatible
|
|
||||||
// The previous filter might have been too complex for the kernel to accept
|
|
||||||
func ApplyCustomBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet, captureMagicHeader uint32) {
|
|
||||||
const ipv4HeaderLen = 20
|
|
||||||
const udpHeaderLen = 8
|
|
||||||
// Magic header would be located after IP + UDP headers
|
|
||||||
const magicHeaderOffset = ipv4HeaderLen + udpHeaderLen
|
|
||||||
|
|
||||||
// Many BPF implementations have limitations on jump offsets and program complexity
|
|
||||||
// Let's create a simpler program that just looks for:
|
|
||||||
// 1. UDP Protocol
|
|
||||||
// 2. Destination port matching our listening port or source port matching our port
|
|
||||||
// 3. We'll handle the magic header check in our application code instead
|
|
||||||
|
|
||||||
// This creates a more basic filter that will be accepted by most kernels
|
|
||||||
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
|
||||||
// Load IP Protocol field (at offset 9)
|
|
||||||
bpf.LoadAbsolute{Off: 9, Size: 1},
|
|
||||||
|
|
||||||
// Is it UDP? (17 is UDP protocol number)
|
|
||||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: 17, SkipFalse: 5, SkipTrue: 0},
|
|
||||||
|
|
||||||
// Load destination port (at IP header + 2)
|
|
||||||
bpf.LoadAbsolute{Off: ipv4HeaderLen + 2, Size: 2},
|
|
||||||
|
|
||||||
// Is it our port?
|
|
||||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 2, SkipTrue: 0},
|
|
||||||
|
|
||||||
// Accept packet
|
|
||||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
|
||||||
|
|
||||||
// Not matching destination port, check source port
|
|
||||||
bpf.LoadAbsolute{Off: ipv4HeaderLen + 0, Size: 2},
|
|
||||||
|
|
||||||
// Is source port our port?
|
|
||||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
|
||||||
|
|
||||||
// Accept packet
|
|
||||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
|
||||||
|
|
||||||
// Reject packet
|
|
||||||
bpf.RetConstant{Val: 0},
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error assembling BPF:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rawConn.SetBPF(bpfRaw)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error setting BPF:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// These helper functions will make it easier to extract information from packets
|
|
||||||
// ExtractUDPPayload extracts the UDP payload from a raw IP packet
|
|
||||||
func ExtractUDPPayload(packet []byte) []byte {
|
|
||||||
if len(packet) < 28 { // IP header (20) + UDP header (8)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return packet[28:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractIPAndPorts extracts source/dest IP and ports from a raw IP packet
|
|
||||||
func ExtractIPAndPorts(packet []byte) (srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) {
|
|
||||||
if len(packet) < 28 {
|
|
||||||
return nil, 0, nil, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP = net.IP(packet[12:16])
|
|
||||||
dstIP = net.IP(packet[16:20])
|
|
||||||
srcPort = binary.BigEndian.Uint16(packet[20:22])
|
|
||||||
dstPort = binary.BigEndian.Uint16(packet[22:24])
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|||||||
79
wg/wg.go
79
wg/wg.go
@@ -80,13 +80,20 @@ func NewFixedPortBind(port uint16) conn.Bind {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
|
||||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||||
if maxPort < minPort {
|
if maxPort < minPort {
|
||||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a slice of all ports in the range
|
// We need to check port+1 as well, so adjust the max port to avoid going out of range
|
||||||
portRange := make([]uint16, maxPort-minPort+1)
|
adjustedMaxPort := maxPort - 1
|
||||||
|
if adjustedMaxPort < minPort {
|
||||||
|
return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a slice of all ports in the range (excluding the last one)
|
||||||
|
portRange := make([]uint16, adjustedMaxPort-minPort+1)
|
||||||
for i := range portRange {
|
for i := range portRange {
|
||||||
portRange[i] = minPort + uint16(i)
|
portRange[i] = minPort + uint16(i)
|
||||||
}
|
}
|
||||||
@@ -100,20 +107,35 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
|||||||
|
|
||||||
// Try each port in the randomized order
|
// Try each port in the randomized order
|
||||||
for _, port := range portRange {
|
for _, port := range portRange {
|
||||||
addr := &net.UDPAddr{
|
// Check if port is available
|
||||||
|
addr1 := &net.UDPAddr{
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
Port: int(port),
|
Port: int(port),
|
||||||
}
|
}
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
conn1, err1 := net.ListenUDP("udp", addr1)
|
||||||
if err != nil {
|
if err1 != nil {
|
||||||
continue // Port is in use or there was an error, try next port
|
continue // Port is in use or there was an error, try next port
|
||||||
}
|
}
|
||||||
_ = conn.SetDeadline(time.Now())
|
|
||||||
conn.Close()
|
// Check if port+1 is also available
|
||||||
|
addr2 := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: int(port + 1),
|
||||||
|
}
|
||||||
|
conn2, err2 := net.ListenUDP("udp", addr2)
|
||||||
|
if err2 != nil {
|
||||||
|
// The next port is not available, so close the first connection and try again
|
||||||
|
conn1.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both ports are available, close connections and return the first port
|
||||||
|
conn1.Close()
|
||||||
|
conn2.Close()
|
||||||
return port, nil
|
return port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||||
@@ -408,6 +430,7 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received message: %v", msg.Data)
|
||||||
var peer Peer
|
var peer Peer
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
@@ -451,8 +474,6 @@ func (s *WireGuardService) addPeer(peer Peer) error {
|
|||||||
return fmt.Errorf("failed to resolve endpoint address: %w", err)
|
return fmt.Errorf("failed to resolve endpoint address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// make the endpoint localhost to test
|
|
||||||
|
|
||||||
peerConfig = wgtypes.PeerConfig{
|
peerConfig = wgtypes.PeerConfig{
|
||||||
PublicKey: pubKey,
|
PublicKey: pubKey,
|
||||||
AllowedIPs: allowedIPs,
|
AllowedIPs: allowedIPs,
|
||||||
@@ -482,6 +503,7 @@ func (s *WireGuardService) addPeer(peer Peer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received message: %v", msg.Data)
|
||||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||||
type RemoveRequest struct {
|
type RemoveRequest struct {
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
@@ -529,38 +551,34 @@ func (s *WireGuardService) removePeer(publicKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received message: %v", msg.Data)
|
||||||
// Define a struct to match the incoming message structure with optional fields
|
// Define a struct to match the incoming message structure with optional fields
|
||||||
type UpdatePeerRequest struct {
|
type UpdatePeerRequest struct {
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
AllowedIPs []string `json:"allowedIps,omitempty"`
|
AllowedIPs []string `json:"allowedIps,omitempty"`
|
||||||
Endpoint string `json:"endpoint,omitempty"`
|
Endpoint string `json:"endpoint,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Error marshaling data: %v", err)
|
logger.Info("Error marshaling data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var request UpdatePeerRequest
|
var request UpdatePeerRequest
|
||||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||||
logger.Info("Error unmarshaling peer data: %v", err)
|
logger.Info("Error unmarshaling peer data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// First, get the current peer configuration to preserve any unmodified fields
|
// First, get the current peer configuration to preserve any unmodified fields
|
||||||
device, err := s.wgClient.Device(s.interfaceName)
|
device, err := s.wgClient.Device(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Error getting WireGuard device: %v", err)
|
logger.Info("Error getting WireGuard device: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := wgtypes.ParseKey(request.PublicKey)
|
pubKey, err := wgtypes.ParseKey(request.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Error parsing public key: %v", err)
|
logger.Info("Error parsing public key: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the existing peer configuration
|
// Find the existing peer configuration
|
||||||
var currentPeer *wgtypes.Peer
|
var currentPeer *wgtypes.Peer
|
||||||
for _, p := range device.Peers {
|
for _, p := range device.Peers {
|
||||||
@@ -569,22 +587,30 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentPeer == nil {
|
if currentPeer == nil {
|
||||||
logger.Info("Peer %s not found, cannot update", request.PublicKey)
|
logger.Info("Peer %s not found, cannot update", request.PublicKey)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the update peer config
|
// Create the update peer config
|
||||||
peerConfig := wgtypes.PeerConfig{
|
peerConfig := wgtypes.PeerConfig{
|
||||||
PublicKey: pubKey,
|
PublicKey: pubKey,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep the default persistent keepalive of 1 second
|
// Keep the default persistent keepalive of 1 second
|
||||||
keepalive := time.Second
|
keepalive := time.Second
|
||||||
peerConfig.PersistentKeepaliveInterval = &keepalive
|
peerConfig.PersistentKeepaliveInterval = &keepalive
|
||||||
|
|
||||||
|
// Handle Endpoint field special case
|
||||||
|
// If Endpoint is included in the request but empty, we want to remove the endpoint
|
||||||
|
// If Endpoint is not included, we don't modify it
|
||||||
|
endpointSpecified := false
|
||||||
|
for key := range msg.Data.(map[string]interface{}) {
|
||||||
|
if key == "endpoint" {
|
||||||
|
endpointSpecified = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Only update AllowedIPs if provided in the request
|
// Only update AllowedIPs if provided in the request
|
||||||
if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 {
|
if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 {
|
||||||
var allowedIPs []net.IPNet
|
var allowedIPs []net.IPNet
|
||||||
@@ -597,18 +623,10 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
|||||||
allowedIPs = append(allowedIPs, *ipNet)
|
allowedIPs = append(allowedIPs, *ipNet)
|
||||||
}
|
}
|
||||||
peerConfig.AllowedIPs = allowedIPs
|
peerConfig.AllowedIPs = allowedIPs
|
||||||
|
peerConfig.ReplaceAllowedIPs = true
|
||||||
logger.Info("Updating AllowedIPs for peer %s", request.PublicKey)
|
logger.Info("Updating AllowedIPs for peer %s", request.PublicKey)
|
||||||
}
|
} else if endpointSpecified && request.Endpoint == "" {
|
||||||
|
peerConfig.ReplaceAllowedIPs = false
|
||||||
// Handle Endpoint field special case
|
|
||||||
// If Endpoint is included in the request but empty, we want to remove the endpoint
|
|
||||||
// If Endpoint is not included, we don't modify it
|
|
||||||
endpointSpecified := false
|
|
||||||
for key := range msg.Data.(map[string]interface{}) {
|
|
||||||
if key == "endpoint" {
|
|
||||||
endpointSpecified = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if endpointSpecified {
|
if endpointSpecified {
|
||||||
@@ -623,7 +641,6 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
|||||||
logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint)
|
logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint)
|
||||||
} else {
|
} else {
|
||||||
// Request contained endpoint field but it was empty/null - remove endpoint
|
// Request contained endpoint field but it was empty/null - remove endpoint
|
||||||
// To remove an endpoint in WireGuard, we set it to nil and specify ReplaceAllowedIPs
|
|
||||||
peerConfig.Endpoint = nil
|
peerConfig.Endpoint = nil
|
||||||
logger.Info("Removing Endpoint for peer %s", request.PublicKey)
|
logger.Info("Removing Endpoint for peer %s", request.PublicKey)
|
||||||
}
|
}
|
||||||
@@ -633,12 +650,10 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
|||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||||
logger.Info("Error updating peer configuration: %v", err)
|
logger.Info("Error updating peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Peer %s updated successfully", request.PublicKey)
|
logger.Info("Peer %s updated successfully", request.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,12 @@ package wgtester
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/network"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -25,9 +24,9 @@ const (
|
|||||||
packetSize = 13
|
packetSize = 13
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server handles listening for connection check requests using raw sockets
|
// Server handles listening for connection check requests using UDP
|
||||||
type Server struct {
|
type Server struct {
|
||||||
rawConn *ipv4.RawConn
|
conn *net.UDPConn
|
||||||
serverAddr string
|
serverAddr string
|
||||||
serverPort uint16
|
serverPort uint16
|
||||||
shutdownCh chan struct{}
|
shutdownCh chan struct{}
|
||||||
@@ -37,18 +36,18 @@ type Server struct {
|
|||||||
outputPrefix string
|
outputPrefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new connection test server using raw sockets
|
// NewServer creates a new connection test server using UDP
|
||||||
func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
|
func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
serverAddr: serverAddr,
|
serverAddr: serverAddr,
|
||||||
serverPort: serverPort,
|
serverPort: serverPort + 1, // use the next port for the server
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
newtID: newtID,
|
newtID: newtID,
|
||||||
outputPrefix: "[WGTester] ",
|
outputPrefix: "[WGTester] ",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins listening for connection test packets using raw sockets
|
// Start begins listening for connection test packets using UDP
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
s.runningLock.Lock()
|
s.runningLock.Lock()
|
||||||
defer s.runningLock.Unlock()
|
defer s.runningLock.Unlock()
|
||||||
@@ -57,30 +56,26 @@ func (s *Server) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure server and client for BPF filtering
|
//create the address to listen on
|
||||||
server := &network.Server{
|
addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort))
|
||||||
Hostname: s.serverAddr,
|
|
||||||
Addr: network.HostToAddr(s.serverAddr),
|
// Create UDP address to listen on
|
||||||
Port: s.serverPort,
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
clientIP := network.GetClientIP(server.Addr.IP)
|
// Create UDP connection
|
||||||
|
conn, err := net.ListenUDP("udp", udpAddr)
|
||||||
// Use the server port as our client port to match the WireGuard configuration
|
if err != nil {
|
||||||
client := &network.PeerNet{
|
return err
|
||||||
IP: clientIP,
|
|
||||||
Port: s.serverPort, // Use same port as server to share with WireGuard
|
|
||||||
NewtID: s.newtID,
|
|
||||||
}
|
}
|
||||||
|
s.conn = conn
|
||||||
// Setup raw connection with custom BPF to filter for our magic header
|
|
||||||
rawConn := network.SetupRawConnWithCustomBPF(server, client, magicHeader)
|
|
||||||
s.rawConn = rawConn
|
|
||||||
|
|
||||||
s.isRunning = true
|
s.isRunning = true
|
||||||
go s.handleConnections()
|
go s.handleConnections()
|
||||||
|
|
||||||
logger.Info(""+s.outputPrefix+"Server started on %s:%d", s.serverAddr, s.serverPort)
|
logger.Info("%sServer started on %s:%d", s.outputPrefix, s.serverAddr, s.serverPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,8 +89,8 @@ func (s *Server) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
close(s.shutdownCh)
|
close(s.shutdownCh)
|
||||||
if s.rawConn != nil {
|
if s.conn != nil {
|
||||||
s.rawConn.Close()
|
s.conn.Close()
|
||||||
}
|
}
|
||||||
s.isRunning = false
|
s.isRunning = false
|
||||||
logger.Info(s.outputPrefix + "Server stopped")
|
logger.Info(s.outputPrefix + "Server stopped")
|
||||||
@@ -103,23 +98,22 @@ func (s *Server) Stop() {
|
|||||||
|
|
||||||
// handleConnections processes incoming packets
|
// handleConnections processes incoming packets
|
||||||
func (s *Server) handleConnections() {
|
func (s *Server) handleConnections() {
|
||||||
|
buffer := make([]byte, 2000) // Buffer large enough for any UDP packet
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-s.shutdownCh:
|
case <-s.shutdownCh:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// Read packet with timeout using RawConn
|
// Set read deadline to avoid blocking forever
|
||||||
err := s.rawConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
err := s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
|
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create buffer for the entire IP packet
|
// Read from UDP connection
|
||||||
payload := make([]byte, 2000) // Large enough for any UDP packet
|
n, addr, err := s.conn.ReadFromUDP(buffer)
|
||||||
|
|
||||||
// Read the packet
|
|
||||||
_, _, _, err = s.rawConn.ReadFrom(payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
// Just a timeout, keep going
|
// Just a timeout, keep going
|
||||||
@@ -129,26 +123,19 @@ func (s *Server) handleConnections() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract IP and port information
|
// Process packet only if it meets minimum size requirements
|
||||||
srcIP, srcPort, _, _ := network.ExtractIPAndPorts(payload)
|
if n < packetSize {
|
||||||
if srcIP == nil {
|
|
||||||
continue // Invalid packet
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract UDP payload
|
|
||||||
udpPayload := network.ExtractUDPPayload(payload)
|
|
||||||
if udpPayload == nil || len(udpPayload) < packetSize {
|
|
||||||
continue // Too small to be our packet
|
continue // Too small to be our packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check magic header
|
// Check magic header
|
||||||
magic := binary.BigEndian.Uint32(udpPayload[0:4])
|
magic := binary.BigEndian.Uint32(buffer[0:4])
|
||||||
if magic != magicHeader {
|
if magic != magicHeader {
|
||||||
continue // Not our packet
|
continue // Not our packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check packet type
|
// Check packet type
|
||||||
packetType := udpPayload[4]
|
packetType := buffer[4]
|
||||||
if packetType != packetTypeRequest {
|
if packetType != packetTypeRequest {
|
||||||
continue // Not a request packet
|
continue // Not a request packet
|
||||||
}
|
}
|
||||||
@@ -160,37 +147,18 @@ func (s *Server) handleConnections() {
|
|||||||
// Change the packet type to response
|
// Change the packet type to response
|
||||||
responsePacket[4] = packetTypeResponse
|
responsePacket[4] = packetTypeResponse
|
||||||
// Copy the timestamp (for RTT calculation)
|
// Copy the timestamp (for RTT calculation)
|
||||||
if len(udpPayload) >= 13 {
|
copy(responsePacket[5:13], buffer[5:13])
|
||||||
copy(responsePacket[5:13], udpPayload[5:13])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the client's source information to send the response
|
|
||||||
peerClient := &network.PeerNet{
|
|
||||||
IP: s.rawConn.LocalAddr().(*net.IPAddr).IP,
|
|
||||||
Port: s.serverPort,
|
|
||||||
NewtID: s.newtID,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup target server from the source of the incoming packet
|
|
||||||
server := &network.Server{
|
|
||||||
Hostname: srcIP.String(),
|
|
||||||
Addr: &net.IPAddr{IP: srcIP},
|
|
||||||
Port: srcPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log response being sent for debugging
|
// Log response being sent for debugging
|
||||||
logger.Debug(s.outputPrefix+"Sending response to %s:%d", srcIP.String(), srcPort)
|
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
|
||||||
|
|
||||||
// Send the response packet
|
// Send the response packet directly to the source address
|
||||||
err = network.SendPacket(responsePacket, s.rawConn, server, peerClient)
|
_, err = s.conn.WriteToUDP(responsePacket, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug(s.outputPrefix + "Response sent successfully")
|
logger.Debug(s.outputPrefix + "Response sent successfully")
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user