mirror of
https://github.com/fosrl/newt.git
synced 2026-03-12 05:36:41 +00:00
Basic holepunch working
This commit is contained in:
159
wg/wg.go
159
wg/wg.go
@@ -5,10 +5,13 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/exp/rand"
|
||||
@@ -214,7 +217,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
s.config = config
|
||||
|
||||
// stop the holepunch
|
||||
close(s.stopHolepunch)
|
||||
// close(s.stopHolepunch)
|
||||
|
||||
// Ensure the WireGuard interface and peers are configured
|
||||
if err := s.ensureWireguardInterface(config); err != nil {
|
||||
@@ -373,94 +376,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (s *WireGuardService) ensureMSSClamping() error {
|
||||
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||
// mssValue := mtuInt - 40
|
||||
|
||||
// // Rules to be managed - just the chains, we'll construct the full command separately
|
||||
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
||||
|
||||
// // First, try to delete any existing rules
|
||||
// for _, chain := range chains {
|
||||
// deleteCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-D", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
||||
|
||||
// // Try deletion multiple times to handle multiple existing rules
|
||||
// for i := 0; i < 3; i++ {
|
||||
// out, err := deleteCmd.CombinedOutput()
|
||||
// if err != nil {
|
||||
// // Convert exit status 1 to string for better logging
|
||||
// if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
||||
// chain, exitErr.String(), string(out))
|
||||
// }
|
||||
// break // No more rules to delete
|
||||
// }
|
||||
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Then add the new rules
|
||||
// var errors []error
|
||||
// for _, chain := range chains {
|
||||
// addCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-A", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
// logger.Info("Adding MSS clamping rule for chain %s", chain)
|
||||
|
||||
// if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
// chain, err, string(out))
|
||||
// logger.Error(errMsg)
|
||||
// errors = append(errors, fmt.Errorf(errMsg))
|
||||
// continue
|
||||
// }
|
||||
|
||||
// // Verify the rule was added
|
||||
// checkCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-C", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
// if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
// chain, err, string(out))
|
||||
// logger.Error(errMsg)
|
||||
// errors = append(errors, fmt.Errorf(errMsg))
|
||||
// continue
|
||||
// }
|
||||
|
||||
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
||||
// }
|
||||
|
||||
// // If we encountered any errors, return them combined
|
||||
// if len(errors) > 0 {
|
||||
// var errMsgs []string
|
||||
// for _, err := range errors {
|
||||
// errMsgs = append(errMsgs, err.Error())
|
||||
// }
|
||||
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
||||
// strings.Join(errMsgs, "\n"))
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
var peer Peer
|
||||
|
||||
@@ -681,40 +596,72 @@ func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
|
||||
// Bind to specific local port
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(s.port),
|
||||
IP: net.IPv4zero,
|
||||
// Parse server address
|
||||
serverSplit := strings.Split(serverAddr, ":")
|
||||
if len(serverSplit) < 2 {
|
||||
return fmt.Errorf("invalid server address format, expected hostname:port")
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
serverHostname := serverSplit[0]
|
||||
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve UDP address: %v", err)
|
||||
return fmt.Errorf("failed to parse server port: %v", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind UDP socket: %v", err)
|
||||
// Resolve server hostname to IP
|
||||
serverIPAddr := network.HostToAddr(serverHostname)
|
||||
if serverIPAddr == nil {
|
||||
return fmt.Errorf("failed to resolve server hostname")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get client IP based on route to server
|
||||
clientIP := network.GetClientIP(serverIPAddr.IP)
|
||||
|
||||
// Create server and client configs
|
||||
server := &network.Server{
|
||||
Hostname: serverHostname,
|
||||
Addr: serverIPAddr,
|
||||
Port: uint16(serverPort),
|
||||
}
|
||||
|
||||
client := &network.PeerNet{
|
||||
IP: clientIP,
|
||||
Port: s.port,
|
||||
NewtID: s.newtId,
|
||||
}
|
||||
|
||||
// Setup raw connection with BPF filtering
|
||||
rawConn := network.SetupRawConn(server, client)
|
||||
defer rawConn.Close()
|
||||
|
||||
// Create JSON payload
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
}{
|
||||
NewtID: s.newtId,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteToUDP(data, remoteAddr)
|
||||
// Send the packet using the raw connection
|
||||
err = network.SendDataPacket(payload, rawConn, server, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Sent UDP hole punch to %s", serverAddr)
|
||||
// logger.Info("Sent UDP hole punch to %s", serverAddr)
|
||||
|
||||
// // Wait for response if needed
|
||||
// response, err := network.RecvDataPacket(rawConn, server, client)
|
||||
// if err != nil {
|
||||
// if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||
// return fmt.Errorf("connection to %s timed out", serverAddr)
|
||||
// }
|
||||
// return fmt.Errorf("error receiving response: %v", err)
|
||||
// }
|
||||
|
||||
// // Process response if needed
|
||||
// if len(response) > 0 {
|
||||
// logger.Info("Received response from server")
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user