mirror of
https://github.com/fosrl/newt.git
synced 2026-03-01 16:26:40 +00:00
Update to use new packages
This commit is contained in:
@@ -2,7 +2,6 @@ package wgnetstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -16,14 +15,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/netstack2"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -66,22 +63,20 @@ type PeerReading struct {
|
||||
}
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
keyFilePath string
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
stopHolepunch chan struct{}
|
||||
host string
|
||||
serverPubKey string
|
||||
holePunchEndpoint string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
keyFilePath string
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
// Netstack fields
|
||||
tun tun.Device
|
||||
tnet *netstack2.Net
|
||||
@@ -95,6 +90,9 @@ type WireGuardService struct {
|
||||
// Proxy manager for tunnel
|
||||
proxyManager *proxy.ProxyManager
|
||||
TunnelIP string
|
||||
// Shared bind and holepunch manager
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
}
|
||||
|
||||
// GetProxyManager returns the proxy manager for this WireGuardService
|
||||
@@ -118,24 +116,6 @@ func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) e
|
||||
return s.proxyManager.RemoveTarget(proto, listenIP, port)
|
||||
}
|
||||
|
||||
// Add this type definition
|
||||
type fixedPortBind struct {
|
||||
port uint16
|
||||
conn.Bind
|
||||
}
|
||||
|
||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// Ignore the requested port and use our fixed port
|
||||
return b.Bind.Open(b.port)
|
||||
}
|
||||
|
||||
func NewFixedPortBind(port uint16) conn.Bind {
|
||||
return &fixedPortBind{
|
||||
port: port,
|
||||
Bind: conn.NewDefaultBind(),
|
||||
}
|
||||
}
|
||||
|
||||
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
@@ -215,6 +195,28 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
||||
return nil, fmt.Errorf("error finding available port: %v", err)
|
||||
}
|
||||
|
||||
// Create shared UDP socket for both holepunch and WireGuard
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(port),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create UDP socket: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := bind.New(udpConn)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
return nil, fmt.Errorf("failed to create shared bind: %v", err)
|
||||
}
|
||||
|
||||
// Add a reference for the hole punch manager (creator already has one reference for WireGuard)
|
||||
sharedBind.AddRef()
|
||||
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", port, sharedBind.GetRefCount())
|
||||
|
||||
// Parse DNS addresses
|
||||
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
|
||||
|
||||
@@ -227,12 +229,16 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
||||
newtId: newtId,
|
||||
host: host,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
stopHolepunch: make(chan struct{}),
|
||||
Port: port,
|
||||
dns: dnsAddrs,
|
||||
proxyManager: proxy.NewProxyManagerWithoutTNet(),
|
||||
sharedBind: sharedBind,
|
||||
}
|
||||
|
||||
// Create the holepunch manager with ResolveDomain function
|
||||
// We'll need to pass a domain resolver function
|
||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId)
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||
@@ -344,10 +350,15 @@ func (s *WireGuardService) Close(rm bool) {
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
|
||||
// Stop hole punch manager
|
||||
if s.holePunchManager != nil {
|
||||
s.holePunchManager.Stop()
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Close WireGuard device first - this will automatically close the TUN device
|
||||
// Close WireGuard device first - this will call sharedBind.Close() which releases WireGuard's reference
|
||||
if s.device != nil {
|
||||
s.device.Close()
|
||||
s.device = nil
|
||||
@@ -360,28 +371,22 @@ func (s *WireGuardService) Close(rm bool) {
|
||||
if s.tun != nil {
|
||||
s.tun = nil // Don't call tun.Close() here since device.Close() already closed it
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) {
|
||||
// if the device is already created dont start a new holepunch
|
||||
if s.device != nil {
|
||||
return
|
||||
// Release the hole punch reference to the shared bind
|
||||
if s.sharedBind != nil {
|
||||
// Release hole punch reference (WireGuard already released its reference via device.Close())
|
||||
logger.Debug("Releasing shared bind (refcount before release: %d)", s.sharedBind.GetRefCount())
|
||||
s.sharedBind.Release()
|
||||
s.sharedBind = nil
|
||||
logger.Info("Released shared UDP bind")
|
||||
}
|
||||
|
||||
s.serverPubKey = serverPubKey
|
||||
s.holePunchEndpoint = endpoint
|
||||
|
||||
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
|
||||
|
||||
// Create a new stop channel for this holepunch session
|
||||
s.stopHolepunch = make(chan struct{})
|
||||
|
||||
// start the UDP holepunch
|
||||
go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetToken(token string) {
|
||||
s.token = token
|
||||
if s.holePunchManager != nil {
|
||||
s.holePunchManager.SetToken(token)
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetstackNet returns the netstack network interface for use by other components
|
||||
@@ -412,6 +417,19 @@ func (s *WireGuardService) SetOnNetstackClose(callback func()) {
|
||||
s.onNetstackClose = callback
|
||||
}
|
||||
|
||||
// StartHolepunch starts hole punching to a specific endpoint
|
||||
func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
|
||||
if s.holePunchManager == nil {
|
||||
logger.Warn("Hole punch manager not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
|
||||
if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
if s.stopGetConfig != nil {
|
||||
s.stopGetConfig()
|
||||
@@ -485,10 +503,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Parse the IP address and CIDR mask
|
||||
tunnelIP := netip.MustParseAddr(parts[0])
|
||||
|
||||
// stop the holepunch its a channel
|
||||
if s.stopHolepunch != nil {
|
||||
close(s.stopHolepunch)
|
||||
s.stopHolepunch = nil
|
||||
// Stop any ongoing hole punch operations
|
||||
if s.holePunchManager != nil {
|
||||
s.holePunchManager.Stop()
|
||||
}
|
||||
|
||||
// Parse the IP address from the config
|
||||
@@ -512,8 +529,8 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// s.proxyManager.SetTNet(s.tnet)
|
||||
s.TunnelIP = tunnelIP.String()
|
||||
|
||||
// Create WireGuard device
|
||||
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
|
||||
// Create WireGuard device using the shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent, // Use silent logging by default - could be made configurable
|
||||
"wireguard: ",
|
||||
))
|
||||
@@ -946,171 +963,6 @@ func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
|
||||
|
||||
if s.serverPubKey == "" || s.token == "" {
|
||||
logger.Debug("Server public key or token not set, skipping UDP hole punch")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse server address
|
||||
serverSplit := strings.Split(serverAddr, ":")
|
||||
if len(serverSplit) < 2 {
|
||||
return fmt.Errorf("invalid server address format, expected hostname:port")
|
||||
}
|
||||
|
||||
serverHostname := serverSplit[0]
|
||||
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse server port: %v", err)
|
||||
}
|
||||
|
||||
// Resolve server hostname to IP
|
||||
serverIPAddr := network.HostToAddr(serverHostname)
|
||||
if serverIPAddr == nil {
|
||||
return fmt.Errorf("failed to resolve server hostname")
|
||||
}
|
||||
|
||||
// Create local UDP address using the same port as WireGuard
|
||||
localAddr := &net.UDPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: int(s.Port),
|
||||
}
|
||||
|
||||
// Create remote server address
|
||||
remoteAddr := &net.UDPAddr{
|
||||
IP: serverIPAddr.IP,
|
||||
Port: int(serverPort),
|
||||
}
|
||||
|
||||
// Create UDP connection bound to the same port as WireGuard
|
||||
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create netstack UDP connection: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Create JSON payload
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
NewtID: s.newtId,
|
||||
Token: s.token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := s.encryptPayload(payloadBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||
}
|
||||
|
||||
// Convert encrypted payload to JSON
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
|
||||
}
|
||||
|
||||
// Send the encrypted packet using the netstack UDP connection
|
||||
_, err = conn.Write(jsonData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
||||
logger.Info("Starting UDP hole punch routine to %s:21820", host)
|
||||
|
||||
// send initial hole punch
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Debug("Failed to send initial UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Debug("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||
var replace = false
|
||||
for _, t := range targetData.Targets {
|
||||
@@ -1242,8 +1094,8 @@ func (s *WireGuardService) ReplaceNetstack() error {
|
||||
s.tun = newTun
|
||||
s.tnet = newTnet
|
||||
|
||||
// Create new WireGuard device with same port
|
||||
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
|
||||
// Create new WireGuard device with same shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent,
|
||||
"wireguard: ",
|
||||
))
|
||||
|
||||
Reference in New Issue
Block a user