Update to use new packages

This commit is contained in:
Owen
2025-11-15 16:14:40 -05:00
parent 972c9a9760
commit c71c6e0b1a
9 changed files with 1314 additions and 291 deletions

View File

@@ -2,7 +2,6 @@ package wgnetstack
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -16,14 +15,12 @@ import (
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/netstack2"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -66,22 +63,20 @@ type PeerReading struct {
}
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
config WgConfig
key wgtypes.Key
keyFilePath string
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
stopHolepunch chan struct{}
host string
serverPubKey string
holePunchEndpoint string
token string
stopGetConfig func()
interfaceName string
mtu int
client *websocket.Client
config WgConfig
key wgtypes.Key
keyFilePath string
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
host string
serverPubKey string
token string
stopGetConfig func()
// Netstack fields
tun tun.Device
tnet *netstack2.Net
@@ -95,6 +90,9 @@ type WireGuardService struct {
// Proxy manager for tunnel
proxyManager *proxy.ProxyManager
TunnelIP string
// Shared bind and holepunch manager
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
}
// GetProxyManager returns the proxy manager for this WireGuardService
@@ -118,24 +116,6 @@ func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) e
return s.proxyManager.RemoveTarget(proto, listenIP, port)
}
// Add this type definition
type fixedPortBind struct {
port uint16
conn.Bind
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
}
}
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
@@ -215,6 +195,28 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
return nil, fmt.Errorf("error finding available port: %v", err)
}
// Create shared UDP socket for both holepunch and WireGuard
localAddr := &net.UDPAddr{
Port: int(port),
IP: net.IPv4zero,
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
return nil, fmt.Errorf("failed to create UDP socket: %v", err)
}
sharedBind, err := bind.New(udpConn)
if err != nil {
udpConn.Close()
return nil, fmt.Errorf("failed to create shared bind: %v", err)
}
// Add a reference for the hole punch manager (creator already has one reference for WireGuard)
sharedBind.AddRef()
logger.Info("Created shared UDP socket on port %d (refcount: %d)", port, sharedBind.GetRefCount())
// Parse DNS addresses
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
@@ -227,12 +229,16 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
newtId: newtId,
host: host,
lastReadings: make(map[string]PeerReading),
stopHolepunch: make(chan struct{}),
Port: port,
dns: dnsAddrs,
proxyManager: proxy.NewProxyManagerWithoutTNet(),
sharedBind: sharedBind,
}
// Create the holepunch manager with ResolveDomain function
// We'll need to pass a domain resolver function
service.holePunchManager = holepunch.NewManager(sharedBind, newtId)
// Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
@@ -344,10 +350,15 @@ func (s *WireGuardService) Close(rm bool) {
s.stopGetConfig = nil
}
// Stop hole punch manager
if s.holePunchManager != nil {
s.holePunchManager.Stop()
}
s.mu.Lock()
defer s.mu.Unlock()
// Close WireGuard device first - this will automatically close the TUN device
// Close WireGuard device first - this will call sharedBind.Close() which releases WireGuard's reference
if s.device != nil {
s.device.Close()
s.device = nil
@@ -360,28 +371,22 @@ func (s *WireGuardService) Close(rm bool) {
if s.tun != nil {
s.tun = nil // Don't call tun.Close() here since device.Close() already closed it
}
}
func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) {
// if the device is already created dont start a new holepunch
if s.device != nil {
return
// Release the hole punch reference to the shared bind
if s.sharedBind != nil {
// Release hole punch reference (WireGuard already released its reference via device.Close())
logger.Debug("Releasing shared bind (refcount before release: %d)", s.sharedBind.GetRefCount())
s.sharedBind.Release()
s.sharedBind = nil
logger.Info("Released shared UDP bind")
}
s.serverPubKey = serverPubKey
s.holePunchEndpoint = endpoint
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
// Create a new stop channel for this holepunch session
s.stopHolepunch = make(chan struct{})
// start the UDP holepunch
go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
}
func (s *WireGuardService) SetToken(token string) {
s.token = token
if s.holePunchManager != nil {
s.holePunchManager.SetToken(token)
}
}
// GetNetstackNet returns the netstack network interface for use by other components
@@ -412,6 +417,19 @@ func (s *WireGuardService) SetOnNetstackClose(callback func()) {
s.onNetstackClose = callback
}
// StartHolepunch starts hole punching to a specific endpoint
func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
if s.holePunchManager == nil {
logger.Warn("Hole punch manager not initialized")
return
}
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
}
func (s *WireGuardService) LoadRemoteConfig() error {
if s.stopGetConfig != nil {
s.stopGetConfig()
@@ -485,10 +503,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Parse the IP address and CIDR mask
tunnelIP := netip.MustParseAddr(parts[0])
// stop the holepunch its a channel
if s.stopHolepunch != nil {
close(s.stopHolepunch)
s.stopHolepunch = nil
// Stop any ongoing hole punch operations
if s.holePunchManager != nil {
s.holePunchManager.Stop()
}
// Parse the IP address from the config
@@ -512,8 +529,8 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// s.proxyManager.SetTNet(s.tnet)
s.TunnelIP = tunnelIP.String()
// Create WireGuard device
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
// Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable
"wireguard: ",
))
@@ -946,171 +963,6 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return nil
}
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
if s.serverPubKey == "" || s.token == "" {
logger.Debug("Server public key or token not set, skipping UDP hole punch")
return nil
}
// Parse server address
serverSplit := strings.Split(serverAddr, ":")
if len(serverSplit) < 2 {
return fmt.Errorf("invalid server address format, expected hostname:port")
}
serverHostname := serverSplit[0]
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
if err != nil {
return fmt.Errorf("failed to parse server port: %v", err)
}
// Resolve server hostname to IP
serverIPAddr := network.HostToAddr(serverHostname)
if serverIPAddr == nil {
return fmt.Errorf("failed to resolve server hostname")
}
// Create local UDP address using the same port as WireGuard
localAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: int(s.Port),
}
// Create remote server address
remoteAddr := &net.UDPAddr{
IP: serverIPAddr.IP,
Port: int(serverPort),
}
// Create UDP connection bound to the same port as WireGuard
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return fmt.Errorf("failed to create netstack UDP connection: %v", err)
}
defer conn.Close()
// Create JSON payload
payload := struct {
NewtID string `json:"newtId"`
Token string `json:"token"`
}{
NewtID: s.newtId,
Token: s.token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := s.encryptPayload(payloadBytes)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
// Convert encrypted payload to JSON
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
}
// Send the encrypted packet using the netstack UDP connection
_, err = conn.Write(jsonData)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String())
return nil
}
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
logger.Info("Starting UDP hole punch routine to %s:21820", host)
// send initial hole punch
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Debug("Failed to send initial UDP hole punch: %v", err)
}
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-s.stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Debug("Failed to send UDP hole punch: %v", err)
}
}
}
}
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
var replace = false
for _, t := range targetData.Targets {
@@ -1242,8 +1094,8 @@ func (s *WireGuardService) ReplaceNetstack() error {
s.tun = newTun
s.tnet = newTnet
// Create new WireGuard device with same port
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
// Create new WireGuard device with same shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent,
"wireguard: ",
))