Basic holepunch working

This commit is contained in:
Owen
2025-02-23 00:59:51 -05:00
parent 8795c57b2e
commit f6429b6eee
5 changed files with 295 additions and 106 deletions

159
wg/wg.go
View File

@@ -5,10 +5,13 @@ import (
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/exp/rand"
@@ -214,7 +217,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
s.config = config
// stop the holepunch
close(s.stopHolepunch)
// close(s.stopHolepunch)
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
@@ -373,94 +376,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil
}
// func (s *WireGuardService) ensureMSSClamping() error {
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
// mssValue := mtuInt - 40
// // Rules to be managed - just the chains, we'll construct the full command separately
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
// // First, try to delete any existing rules
// for _, chain := range chains {
// deleteCmd := exec.Command("/usr/sbin/iptables",
// "-t", "mangle",
// "-D", chain,
// "-p", "tcp",
// "--tcp-flags", "SYN,RST", "SYN",
// "-j", "TCPMSS",
// "--set-mss", fmt.Sprintf("%d", mssValue))
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
// // Try deletion multiple times to handle multiple existing rules
// for i := 0; i < 3; i++ {
// out, err := deleteCmd.CombinedOutput()
// if err != nil {
// // Convert exit status 1 to string for better logging
// if exitErr, ok := err.(*exec.ExitError); ok {
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
// chain, exitErr.String(), string(out))
// }
// break // No more rules to delete
// }
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
// }
// }
// // Then add the new rules
// var errors []error
// for _, chain := range chains {
// addCmd := exec.Command("/usr/sbin/iptables",
// "-t", "mangle",
// "-A", chain,
// "-p", "tcp",
// "--tcp-flags", "SYN,RST", "SYN",
// "-j", "TCPMSS",
// "--set-mss", fmt.Sprintf("%d", mssValue))
// logger.Info("Adding MSS clamping rule for chain %s", chain)
// if out, err := addCmd.CombinedOutput(); err != nil {
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
// chain, err, string(out))
// logger.Error(errMsg)
// errors = append(errors, fmt.Errorf(errMsg))
// continue
// }
// // Verify the rule was added
// checkCmd := exec.Command("/usr/sbin/iptables",
// "-t", "mangle",
// "-C", chain,
// "-p", "tcp",
// "--tcp-flags", "SYN,RST", "SYN",
// "-j", "TCPMSS",
// "--set-mss", fmt.Sprintf("%d", mssValue))
// if out, err := checkCmd.CombinedOutput(); err != nil {
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
// chain, err, string(out))
// logger.Error(errMsg)
// errors = append(errors, fmt.Errorf(errMsg))
// continue
// }
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
// }
// // If we encountered any errors, return them combined
// if len(errors) > 0 {
// var errMsgs []string
// for _, err := range errors {
// errMsgs = append(errMsgs, err.Error())
// }
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
// strings.Join(errMsgs, "\n"))
// }
// return nil
// }
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer
@@ -681,40 +596,72 @@ func (s *WireGuardService) reportPeerBandwidth() error {
}
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
// Bind to specific local port
localAddr := &net.UDPAddr{
Port: int(s.port),
IP: net.IPv4zero,
// Parse server address
serverSplit := strings.Split(serverAddr, ":")
if len(serverSplit) < 2 {
return fmt.Errorf("invalid server address format, expected hostname:port")
}
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
serverHostname := serverSplit[0]
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %v", err)
return fmt.Errorf("failed to parse server port: %v", err)
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
return fmt.Errorf("failed to bind UDP socket: %v", err)
// Resolve server hostname to IP
serverIPAddr := network.HostToAddr(serverHostname)
if serverIPAddr == nil {
return fmt.Errorf("failed to resolve server hostname")
}
defer conn.Close()
// Get client IP based on route to server
clientIP := network.GetClientIP(serverIPAddr.IP)
// Create server and client configs
server := &network.Server{
Hostname: serverHostname,
Addr: serverIPAddr,
Port: uint16(serverPort),
}
client := &network.PeerNet{
IP: clientIP,
Port: s.port,
NewtID: s.newtId,
}
// Setup raw connection with BPF filtering
rawConn := network.SetupRawConn(server, client)
defer rawConn.Close()
// Create JSON payload
payload := struct {
NewtID string `json:"newtId"`
}{
NewtID: s.newtId,
}
data, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
_, err = conn.WriteToUDP(data, remoteAddr)
// Send the packet using the raw connection
err = network.SendDataPacket(payload, rawConn, server, client)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
logger.Info("Sent UDP hole punch to %s", serverAddr)
// logger.Info("Sent UDP hole punch to %s", serverAddr)
// // Wait for response if needed
// response, err := network.RecvDataPacket(rawConn, server, client)
// if err != nil {
// if err, ok := err.(net.Error); ok && err.Timeout() {
// return fmt.Errorf("connection to %s timed out", serverAddr)
// }
// return fmt.Errorf("error receiving response: %v", err)
// }
// // Process response if needed
// if len(response) > 0 {
// logger.Info("Received response from server")
// }
return nil
}