mirror of
https://github.com/fosrl/newt.git
synced 2026-02-22 21:06:38 +00:00
Adjust wgtester to work with bpf
This commit is contained in:
17
main.go
17
main.go
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/fosrl/newt/proxy"
|
"github.com/fosrl/newt/proxy"
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/newt/websocket"
|
||||||
"github.com/fosrl/newt/wg"
|
"github.com/fosrl/newt/wg"
|
||||||
|
"github.com/fosrl/newt/wgtester"
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -442,6 +443,7 @@ func main() {
|
|||||||
var pm *proxy.ProxyManager
|
var pm *proxy.ProxyManager
|
||||||
var connected bool
|
var connected bool
|
||||||
var wgData WgData
|
var wgData WgData
|
||||||
|
var wgTesterServer *wgtester.Server
|
||||||
|
|
||||||
if generateAndSaveKeyTo != "" {
|
if generateAndSaveKeyTo != "" {
|
||||||
// make sure we are running on linux
|
// make sure we are running on linux
|
||||||
@@ -465,6 +467,17 @@ func main() {
|
|||||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||||
}
|
}
|
||||||
defer wgService.Close()
|
defer wgService.Close()
|
||||||
|
|
||||||
|
wgTesterServer = wgtester.NewServer("0.0.0.0", wgService.Port, id) // TODO: maybe make this the same ip of the wg server?
|
||||||
|
err := wgTesterServer.Start()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to start WireGuard tester server: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("WireGuard connection testing server started on port %d", wgService.Port)
|
||||||
|
|
||||||
|
// Make sure to stop the server on exit
|
||||||
|
defer wgTesterServer.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||||
@@ -711,6 +724,10 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
|||||||
wgService.Close()
|
wgService.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if wgTesterServer != nil {
|
||||||
|
wgTesterServer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
if pm != nil {
|
if pm != nil {
|
||||||
pm.Stop()
|
pm.Stop()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,3 +200,99 @@ func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16)
|
|||||||
dstPort = binary.BigEndian.Uint16(response[22:24])
|
dstPort = binary.BigEndian.Uint16(response[22:24])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetupRawConnWithCustomBPF creates an ipv4 and udp RawConn with a custom BPF program
|
||||||
|
// This allows sharing the port between WireGuard and the WGTester
|
||||||
|
func SetupRawConnWithCustomBPF(server *Server, client *PeerNet, captureMagicHeader uint32) *ipv4.RawConn {
|
||||||
|
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Error creating packetConn:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawConn, err := ipv4.NewRawConn(packetConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Error creating rawConn:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply a BPF that allows capturing both WireGuard and tester packets
|
||||||
|
ApplyCustomBPF(rawConn, server, client, captureMagicHeader)
|
||||||
|
|
||||||
|
return rawConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyCustomBPF constructs a simpler BPF program that should be more compatible
|
||||||
|
// The previous filter might have been too complex for the kernel to accept
|
||||||
|
func ApplyCustomBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet, captureMagicHeader uint32) {
|
||||||
|
const ipv4HeaderLen = 20
|
||||||
|
const udpHeaderLen = 8
|
||||||
|
// Magic header would be located after IP + UDP headers
|
||||||
|
const magicHeaderOffset = ipv4HeaderLen + udpHeaderLen
|
||||||
|
|
||||||
|
// Many BPF implementations have limitations on jump offsets and program complexity
|
||||||
|
// Let's create a simpler program that just looks for:
|
||||||
|
// 1. UDP Protocol
|
||||||
|
// 2. Destination port matching our listening port or source port matching our port
|
||||||
|
// 3. We'll handle the magic header check in our application code instead
|
||||||
|
|
||||||
|
// This creates a more basic filter that will be accepted by most kernels
|
||||||
|
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
||||||
|
// Load IP Protocol field (at offset 9)
|
||||||
|
bpf.LoadAbsolute{Off: 9, Size: 1},
|
||||||
|
|
||||||
|
// Is it UDP? (17 is UDP protocol number)
|
||||||
|
bpf.JumpIf{Cond: bpf.JumpEqual, Val: 17, SkipFalse: 5, SkipTrue: 0},
|
||||||
|
|
||||||
|
// Load destination port (at IP header + 2)
|
||||||
|
bpf.LoadAbsolute{Off: ipv4HeaderLen + 2, Size: 2},
|
||||||
|
|
||||||
|
// Is it our port?
|
||||||
|
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 2, SkipTrue: 0},
|
||||||
|
|
||||||
|
// Accept packet
|
||||||
|
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
||||||
|
|
||||||
|
// Not matching destination port, check source port
|
||||||
|
bpf.LoadAbsolute{Off: ipv4HeaderLen + 0, Size: 2},
|
||||||
|
|
||||||
|
// Is source port our port?
|
||||||
|
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
||||||
|
|
||||||
|
// Accept packet
|
||||||
|
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
||||||
|
|
||||||
|
// Reject packet
|
||||||
|
bpf.RetConstant{Val: 0},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Error assembling BPF:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rawConn.SetBPF(bpfRaw)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Error setting BPF:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// These helper functions will make it easier to extract information from packets
|
||||||
|
// ExtractUDPPayload extracts the UDP payload from a raw IP packet
|
||||||
|
func ExtractUDPPayload(packet []byte) []byte {
|
||||||
|
if len(packet) < 28 { // IP header (20) + UDP header (8)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return packet[28:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractIPAndPorts extracts source/dest IP and ports from a raw IP packet
|
||||||
|
func ExtractIPAndPorts(packet []byte) (srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) {
|
||||||
|
if len(packet) < 28 {
|
||||||
|
return nil, 0, nil, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP = net.IP(packet[12:16])
|
||||||
|
dstIP = net.IP(packet[16:20])
|
||||||
|
srcPort = binary.BigEndian.Uint16(packet[20:22])
|
||||||
|
dstPort = binary.BigEndian.Uint16(packet[22:24])
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
package wgtester
|
package wgtester
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/network"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -23,24 +25,30 @@ const (
|
|||||||
packetSize = 13
|
packetSize = 13
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server handles listening for connection check requests
|
// Server handles listening for connection check requests using raw sockets
|
||||||
type Server struct {
|
type Server struct {
|
||||||
conn *net.UDPConn
|
rawConn *ipv4.RawConn
|
||||||
listenAddr string
|
serverAddr string
|
||||||
shutdownCh chan struct{}
|
serverPort uint16
|
||||||
isRunning bool
|
shutdownCh chan struct{}
|
||||||
runningLock sync.Mutex
|
isRunning bool
|
||||||
|
runningLock sync.Mutex
|
||||||
|
newtID string
|
||||||
|
outputPrefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new connection test server
|
// NewServer creates a new connection test server using raw sockets
|
||||||
func NewServer(listenAddr string) *Server {
|
func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
listenAddr: listenAddr,
|
serverAddr: serverAddr,
|
||||||
shutdownCh: make(chan struct{}),
|
serverPort: serverPort,
|
||||||
|
shutdownCh: make(chan struct{}),
|
||||||
|
newtID: newtID,
|
||||||
|
outputPrefix: "[WGTester] ",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins listening for connection test packets
|
// Start begins listening for connection test packets using raw sockets
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
s.runningLock.Lock()
|
s.runningLock.Lock()
|
||||||
defer s.runningLock.Unlock()
|
defer s.runningLock.Unlock()
|
||||||
@@ -49,20 +57,30 @@ func (s *Server) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", s.listenAddr)
|
// Configure server and client for BPF filtering
|
||||||
if err != nil {
|
server := &network.Server{
|
||||||
return err
|
Hostname: s.serverAddr,
|
||||||
|
Addr: network.HostToAddr(s.serverAddr),
|
||||||
|
Port: s.serverPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.conn, err = net.ListenUDP("udp", addr)
|
clientIP := network.GetClientIP(server.Addr.IP)
|
||||||
if err != nil {
|
|
||||||
return err
|
// Use the server port as our client port to match the WireGuard configuration
|
||||||
|
client := &network.PeerNet{
|
||||||
|
IP: clientIP,
|
||||||
|
Port: s.serverPort, // Use same port as server to share with WireGuard
|
||||||
|
NewtID: s.newtID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup raw connection with custom BPF to filter for our magic header
|
||||||
|
rawConn := network.SetupRawConnWithCustomBPF(server, client, magicHeader)
|
||||||
|
s.rawConn = rawConn
|
||||||
|
|
||||||
s.isRunning = true
|
s.isRunning = true
|
||||||
go s.handleConnections()
|
go s.handleConnections()
|
||||||
|
|
||||||
log.Printf("Server listening on %s", s.listenAddr)
|
logger.Info(""+s.outputPrefix+"Server started on %s:%d", s.serverAddr, s.serverPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,272 +94,103 @@ func (s *Server) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
close(s.shutdownCh)
|
close(s.shutdownCh)
|
||||||
if s.conn != nil {
|
if s.rawConn != nil {
|
||||||
s.conn.Close()
|
s.rawConn.Close()
|
||||||
}
|
}
|
||||||
s.isRunning = false
|
s.isRunning = false
|
||||||
log.Println("Server stopped")
|
logger.Info(s.outputPrefix + "Server stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConnections processes incoming packets
|
// handleConnections processes incoming packets
|
||||||
func (s *Server) handleConnections() {
|
func (s *Server) handleConnections() {
|
||||||
buffer := make([]byte, packetSize)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-s.shutdownCh:
|
case <-s.shutdownCh:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// Set read deadline to avoid blocking forever
|
// Read packet with timeout using RawConn
|
||||||
s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
err := s.rawConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(s.outputPrefix+"Error setting read deadline: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
n, addr, err := s.conn.ReadFromUDP(buffer)
|
// Create buffer for the entire IP packet
|
||||||
|
payload := make([]byte, 2000) // Large enough for any UDP packet
|
||||||
|
|
||||||
|
// Read the packet
|
||||||
|
_, _, _, err = s.rawConn.ReadFrom(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
// Just a timeout, keep going
|
// Just a timeout, keep going
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Error reading from UDP: %v", err)
|
logger.Error(s.outputPrefix+"Error reading from UDP: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if n != packetSize {
|
// Extract IP and port information
|
||||||
continue // Ignore malformed packets
|
srcIP, srcPort, _, _ := network.ExtractIPAndPorts(payload)
|
||||||
|
if srcIP == nil {
|
||||||
|
continue // Invalid packet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract UDP payload
|
||||||
|
udpPayload := network.ExtractUDPPayload(payload)
|
||||||
|
if udpPayload == nil || len(udpPayload) < packetSize {
|
||||||
|
continue // Too small to be our packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check magic header
|
// Check magic header
|
||||||
magic := binary.BigEndian.Uint32(buffer[0:4])
|
magic := binary.BigEndian.Uint32(udpPayload[0:4])
|
||||||
if magic != magicHeader {
|
if magic != magicHeader {
|
||||||
continue // Not our packet
|
continue // Not our packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check packet type
|
// Check packet type
|
||||||
packetType := buffer[4]
|
packetType := udpPayload[4]
|
||||||
if packetType != packetTypeRequest {
|
if packetType != packetTypeRequest {
|
||||||
continue // Not a request packet
|
continue // Not a request packet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep the timestamp the same (for RTT calculation)
|
// Create response packet
|
||||||
// Just change the packet type to response
|
responsePacket := make([]byte, packetSize)
|
||||||
buffer[4] = packetTypeResponse
|
// Copy the same magic header
|
||||||
|
binary.BigEndian.PutUint32(responsePacket[0:4], magicHeader)
|
||||||
|
// Change the packet type to response
|
||||||
|
responsePacket[4] = packetTypeResponse
|
||||||
|
// Copy the timestamp (for RTT calculation)
|
||||||
|
if len(udpPayload) >= 13 {
|
||||||
|
copy(responsePacket[5:13], udpPayload[5:13])
|
||||||
|
}
|
||||||
|
|
||||||
// Send response
|
// Use the client's source information to send the response
|
||||||
_, err = s.conn.WriteToUDP(buffer, addr)
|
peerClient := &network.PeerNet{
|
||||||
|
IP: s.rawConn.LocalAddr().(*net.IPAddr).IP,
|
||||||
|
Port: s.serverPort,
|
||||||
|
NewtID: s.newtID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup target server from the source of the incoming packet
|
||||||
|
server := &network.Server{
|
||||||
|
Hostname: srcIP.String(),
|
||||||
|
Addr: &net.IPAddr{IP: srcIP},
|
||||||
|
Port: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log response being sent for debugging
|
||||||
|
logger.Debug(s.outputPrefix+"Sending response to %s:%d", srcIP.String(), srcPort)
|
||||||
|
|
||||||
|
// Send the response packet
|
||||||
|
err = network.SendPacket(responsePacket, s.rawConn, server, peerClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error sending response: %v", err)
|
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug(s.outputPrefix + "Response sent successfully")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(s.outputPrefix+"Error sending response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client handles checking connectivity to a server
|
|
||||||
type Client struct {
|
|
||||||
conn *net.UDPConn
|
|
||||||
serverAddr string
|
|
||||||
monitorRunning bool
|
|
||||||
monitorLock sync.Mutex
|
|
||||||
shutdownCh chan struct{}
|
|
||||||
packetInterval time.Duration
|
|
||||||
timeout time.Duration
|
|
||||||
maxAttempts int
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionStatus represents the current connection state
|
|
||||||
type ConnectionStatus struct {
|
|
||||||
Connected bool
|
|
||||||
RTT time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new connection test client
|
|
||||||
func NewClient(serverAddr string) (*Client, error) {
|
|
||||||
return &Client{
|
|
||||||
serverAddr: serverAddr,
|
|
||||||
shutdownCh: make(chan struct{}),
|
|
||||||
packetInterval: 2 * time.Second,
|
|
||||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
|
||||||
maxAttempts: 3, // Default max attempts
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
|
||||||
func (c *Client) SetPacketInterval(interval time.Duration) {
|
|
||||||
c.packetInterval = interval
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimeout changes the timeout for waiting for responses
|
|
||||||
func (c *Client) SetTimeout(timeout time.Duration) {
|
|
||||||
c.timeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
|
||||||
func (c *Client) SetMaxAttempts(attempts int) {
|
|
||||||
c.maxAttempts = attempts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close cleans up client resources
|
|
||||||
func (c *Client) Close() {
|
|
||||||
c.StopMonitor()
|
|
||||||
if c.conn != nil {
|
|
||||||
c.conn.Close()
|
|
||||||
c.conn = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureConnection makes sure we have an active UDP connection
|
|
||||||
func (c *Client) ensureConnection() error {
|
|
||||||
if c.conn != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestConnection checks if the connection to the server is working
|
|
||||||
// Returns true if connected, false otherwise
|
|
||||||
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
|
||||||
if err := c.ensureConnection(); err != nil {
|
|
||||||
return false, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare packet buffer
|
|
||||||
packet := make([]byte, packetSize)
|
|
||||||
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
|
||||||
packet[4] = packetTypeRequest
|
|
||||||
|
|
||||||
// Send multiple attempts as specified
|
|
||||||
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false, 0
|
|
||||||
default:
|
|
||||||
// Add current timestamp to packet
|
|
||||||
timestamp := time.Now().UnixNano()
|
|
||||||
binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp))
|
|
||||||
|
|
||||||
// Send the packet
|
|
||||||
_, err := c.conn.Write(packet)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error sending packet: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set read deadline
|
|
||||||
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
|
||||||
|
|
||||||
// Wait for response
|
|
||||||
responseBuffer := make([]byte, packetSize)
|
|
||||||
n, err := c.conn.Read(responseBuffer)
|
|
||||||
if err != nil {
|
|
||||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
||||||
// Timeout, try next attempt
|
|
||||||
time.Sleep(100 * time.Millisecond) // Brief pause between attempts
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Printf("Error reading response: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if n != packetSize {
|
|
||||||
continue // Malformed packet
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify response
|
|
||||||
magic := binary.BigEndian.Uint32(responseBuffer[0:4])
|
|
||||||
packetType := responseBuffer[4]
|
|
||||||
if magic != magicHeader || packetType != packetTypeResponse {
|
|
||||||
continue // Not our response
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the original timestamp and calculate RTT
|
|
||||||
sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13]))
|
|
||||||
rtt := time.Duration(time.Now().UnixNano() - sentTimestamp)
|
|
||||||
|
|
||||||
return true, rtt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestConnectionWithTimeout tries to test connection with a timeout
|
|
||||||
// Returns true if connected, false otherwise
|
|
||||||
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
return c.TestConnection(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MonitorCallback is the function type for connection status change callbacks
|
|
||||||
type MonitorCallback func(status ConnectionStatus)
|
|
||||||
|
|
||||||
// StartMonitor begins monitoring the connection and calls the callback
|
|
||||||
// when the connection status changes
|
|
||||||
func (c *Client) StartMonitor(callback MonitorCallback) error {
|
|
||||||
c.monitorLock.Lock()
|
|
||||||
defer c.monitorLock.Unlock()
|
|
||||||
|
|
||||||
if c.monitorRunning {
|
|
||||||
return nil // Already running
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.ensureConnection(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.monitorRunning = true
|
|
||||||
c.shutdownCh = make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
var lastConnected bool
|
|
||||||
firstRun := true
|
|
||||||
|
|
||||||
ticker := time.NewTicker(c.packetInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.shutdownCh:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
|
||||||
connected, rtt := c.TestConnection(ctx)
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
// Callback if status changed or it's the first check
|
|
||||||
if connected != lastConnected || firstRun {
|
|
||||||
callback(ConnectionStatus{
|
|
||||||
Connected: connected,
|
|
||||||
RTT: rtt,
|
|
||||||
})
|
|
||||||
lastConnected = connected
|
|
||||||
firstRun = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StopMonitor stops the connection monitoring
|
|
||||||
func (c *Client) StopMonitor() {
|
|
||||||
c.monitorLock.Lock()
|
|
||||||
defer c.monitorLock.Unlock()
|
|
||||||
|
|
||||||
if !c.monitorRunning {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
close(c.shutdownCh)
|
|
||||||
c.monitorRunning = false
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user