Remove proxy manager and break out subnet proxy

This commit is contained in:
Owen
2025-11-15 21:46:32 -05:00
parent f49a276259
commit 491180c6a1
3 changed files with 372 additions and 533 deletions

321
netstack2/proxy.go Normal file
View File

@@ -0,0 +1,321 @@
package netstack2
import (
"fmt"
"net/netip"
"sync"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// PortRange represents an allowed range of ports (inclusive)
type PortRange struct {
Min uint16
Max uint16
}
// SubnetRule represents a subnet with optional port restrictions
type SubnetRule struct {
Prefix netip.Prefix
PortRanges []PortRange // empty slice means all ports allowed
}
// SubnetLookup provides fast IP subnet and port matching
type SubnetLookup struct {
mu sync.RWMutex
rules []SubnetRule
}
// NewSubnetLookup creates a new subnet lookup table
func NewSubnetLookup() *SubnetLookup {
return &SubnetLookup{
rules: make([]SubnetRule, 0),
}
}
// AddSubnet adds a subnet to the lookup table with optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet
func (sl *SubnetLookup) AddSubnet(prefix netip.Prefix, portRanges []PortRange) {
sl.mu.Lock()
defer sl.mu.Unlock()
sl.rules = append(sl.rules, SubnetRule{
Prefix: prefix,
PortRanges: portRanges,
})
}
// RemoveSubnet removes a subnet from the lookup table
func (sl *SubnetLookup) RemoveSubnet(prefix netip.Prefix) {
sl.mu.Lock()
defer sl.mu.Unlock()
for i, rule := range sl.rules {
if rule.Prefix == prefix {
sl.rules = append(sl.rules[:i], sl.rules[i+1:]...)
return
}
}
}
// Match checks if an IP and port match any subnet rule
// Returns true if the IP is in a matching subnet AND the port is in an allowed range
func (sl *SubnetLookup) Match(ip netip.Addr, port uint16) bool {
sl.mu.RLock()
defer sl.mu.RUnlock()
for _, rule := range sl.rules {
if rule.Prefix.Contains(ip) {
// If no port ranges specified, all ports are allowed
if len(rule.PortRanges) == 0 {
return true
}
// Check if port is in any of the allowed ranges
for _, pr := range rule.PortRanges {
if port >= pr.Min && port <= pr.Max {
return true
}
}
}
}
return false
}
// ProxyHandler handles packet injection and extraction for promiscuous mode
type ProxyHandler struct {
proxyStack *stack.Stack
proxyEp *channel.Endpoint
proxyNotifyHandle *channel.NotificationHandle
tcpHandler *TCPHandler
udpHandler *UDPHandler
subnetLookup *SubnetLookup
enabled bool
}
// ProxyHandlerOptions configures the proxy handler
type ProxyHandlerOptions struct {
EnableTCP bool
EnableUDP bool
MTU int
}
// NewProxyHandler creates a new proxy handler for promiscuous mode
func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
if !options.EnableTCP && !options.EnableUDP {
return nil, nil // No proxy needed
}
handler := &ProxyHandler{
enabled: true,
subnetLookup: NewSubnetLookup(),
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
}),
}
// Initialize TCP handler if enabled
if options.EnableTCP {
handler.tcpHandler = NewTCPHandler(handler.proxyStack)
if err := handler.tcpHandler.InstallTCPHandler(); err != nil {
return nil, fmt.Errorf("failed to install TCP handler: %v", err)
}
}
// Initialize UDP handler if enabled
if options.EnableUDP {
handler.udpHandler = NewUDPHandler(handler.proxyStack)
if err := handler.udpHandler.InstallUDPHandler(); err != nil {
return nil, fmt.Errorf("failed to install UDP handler: %v", err)
}
}
// Example 1: Add a subnet with no port restrictions (all ports allowed)
// This accepts all traffic to 10.20.20.0/24
subnet1 := netip.MustParsePrefix("10.20.20.0/24")
handler.AddSubnetRule(subnet1, nil)
// Example 2: Add a subnet with specific port ranges
// This accepts traffic to 192.168.1.0/24 only on ports 80, 443, and 8000-9000
subnet2 := netip.MustParsePrefix("10.20.21.21/32")
handler.AddSubnetRule(subnet2, []PortRange{
{Min: 12000, Max: 12001},
{Min: 8000, Max: 8000},
})
return handler, nil
}
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(prefix netip.Prefix, portRanges []PortRange) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.AddSubnet(prefix, portRanges)
}
// RemoveSubnetRule removes a subnet from the proxy handler
func (p *ProxyHandler) RemoveSubnetRule(prefix netip.Prefix) {
if p == nil || !p.enabled {
return
}
p.subnetLookup.RemoveSubnet(prefix)
}
// Initialize sets up the promiscuous NIC with the netTun's notification system
func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
if p == nil || !p.enabled {
return nil
}
// Add notification handler
p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable)
// Create NIC with promiscuous mode
tcpipErr := p.proxyStack.CreateNICWithOptions(1, p.proxyEp, stack.NICOptions{
Disabled: false,
QDisc: nil,
})
if tcpipErr != nil {
return fmt.Errorf("CreateNIC (proxy): %v", tcpipErr)
}
// Enable promiscuous mode - accepts packets for any destination IP
if tcpipErr := p.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil {
return fmt.Errorf("SetPromiscuousMode: %v", tcpipErr)
}
// Enable spoofing - allows sending packets from any source IP
if tcpipErr := p.proxyStack.SetSpoofing(1, true); tcpipErr != nil {
return fmt.Errorf("SetSpoofing: %v", tcpipErr)
}
// Add default route
p.proxyStack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
return nil
}
// HandleIncomingPacket processes incoming packets and determines if they should
// be injected into the proxy stack
func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
if p == nil || !p.enabled {
return false
}
// Check minimum packet size
if len(packet) < header.IPv4MinimumSize {
return false
}
// Only handle IPv4 for now
if packet[0]>>4 != 4 {
return false
}
// Parse IPv4 header
ipv4Header := header.IPv4(packet)
dstIP := ipv4Header.DestinationAddress()
// Convert gvisor tcpip.Address to netip.Addr
dstBytes := dstIP.As4()
addr := netip.AddrFrom4(dstBytes)
// Parse transport layer to get destination port
var dstPort uint16
protocol := ipv4Header.TransportProtocol()
headerLen := int(ipv4Header.HeaderLength())
// Extract port based on protocol
switch protocol {
case header.TCPProtocolNumber:
if len(packet) < headerLen+header.TCPMinimumSize {
return false
}
tcpHeader := header.TCP(packet[headerLen:])
dstPort = tcpHeader.DestinationPort()
case header.UDPProtocolNumber:
if len(packet) < headerLen+header.UDPMinimumSize {
return false
}
udpHeader := header.UDP(packet[headerLen:])
dstPort = udpHeader.DestinationPort()
default:
// For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions)
dstPort = 0
}
// Check if the destination IP and port match any subnet rule
if p.subnetLookup.Match(addr, dstPort) {
// Inject into proxy stack
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
return false
}
// ReadOutgoingPacket reads packets from the proxy stack that need to be
// sent back through the tunnel
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
if p == nil || !p.enabled {
return nil
}
pkt := p.proxyEp.Read()
if pkt != nil {
view := pkt.ToView()
pkt.DecRef()
return view
}
return nil
}
// Close cleans up the proxy handler resources
func (p *ProxyHandler) Close() error {
if p == nil || !p.enabled {
return nil
}
if p.proxyStack != nil {
p.proxyStack.RemoveNIC(1)
p.proxyStack.Close()
}
if p.proxyEp != nil {
if p.proxyNotifyHandle != nil {
p.proxyEp.RemoveNotify(p.proxyNotifyHandle)
}
p.proxyEp.Close()
}
return nil
}

View File

@@ -22,7 +22,6 @@ import (
"syscall"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@@ -41,19 +40,15 @@ import (
)
type netTun struct {
ep *channel.Endpoint
proxyEp *channel.Endpoint // Separate endpoint for promiscuous mode
stack *stack.Stack
proxyStack *stack.Stack // Separate stack for proxy endpoint
events chan tun.Event
notifyHandle *channel.NotificationHandle
proxyNotifyHandle *channel.NotificationHandle // Notify handle for proxy endpoint
incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
tcpHandler *TCPHandler
udpHandler *UDPHandler
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
notifyHandle *channel.NotificationHandle
incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
proxyHandler *ProxyHandler // Handles promiscuous mode packet processing
}
type Net netTun
@@ -80,27 +75,27 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
HandleLocal: true,
}
dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
proxyEp: channel.New(1024, uint32(mtu), ""),
stack: stack.New(stackOpts),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
}),
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(stackOpts),
events: make(chan tun.Event, 10),
incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
}
// Initialize proxy handler if TCP or UDP proxying is enabled
if options.EnableTCPProxy || options.EnableUDPProxy {
var err error
dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{
EnableTCP: options.EnableTCPProxy,
EnableUDP: options.EnableUDPProxy,
MTU: mtu,
})
if err != nil {
return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err)
}
}
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
if tcpipErr != nil {
@@ -113,6 +108,13 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
// Initialize proxy handler after main stack is set up
if dev.proxyHandler != nil {
if err := dev.proxyHandler.Initialize(dev); err != nil {
return nil, nil, err
}
}
if err := dev.stack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
return nil, nil, fmt.Errorf("set ipv4 forwarding: %s", err)
}
@@ -145,111 +147,6 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
}
// Add specific route for proxy network (10.20.20.0/24) to NIC 2
if options.EnableTCPProxy || options.EnableUDPProxy {
if options.EnableTCPProxy {
dev.tcpHandler = NewTCPHandler(dev.proxyStack)
if err := dev.tcpHandler.InstallTCPHandler(); err != nil {
return nil, nil, fmt.Errorf("failed to install TCP handler: %v", err)
}
}
if options.EnableUDPProxy {
dev.udpHandler = NewUDPHandler(dev.proxyStack)
if err := dev.udpHandler.InstallUDPHandler(); err != nil {
return nil, nil, fmt.Errorf("failed to install UDP handler: %v", err)
}
}
dev.proxyNotifyHandle = dev.proxyEp.AddNotify(dev)
tcpipErr = dev.proxyStack.CreateNICWithOptions(1, dev.proxyEp, stack.NICOptions{
Disabled: false,
// If no queueing discipline was specified
// provide a stub implementation that just
// delegates to the lower link endpoint.
QDisc: nil,
})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC 2 (proxy): %v", tcpipErr)
}
// Enable promiscuous mode ONLY on NIC 2
// This allows the NIC to accept packets destined for any IP address
if tcpipErr := dev.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil {
return nil, nil, fmt.Errorf("SetPromiscuousMode on NIC 2: %v", tcpipErr)
}
// Enable spoofing ONLY on NIC 2
// This allows the stack to send packets from any address, not just owned addresses
if tcpipErr := dev.proxyStack.SetSpoofing(1, true); tcpipErr != nil {
return nil, nil, fmt.Errorf("SetSpoofing on NIC 2: %v", tcpipErr)
}
// // Add a wildcard IPv4 address covering the 10.0.0.0/8 space so the stack can
// // synthesize temporary endpoints for any 10.x.y.z destination. This mimics
// // the tun2socks behaviour and is required once promiscuous+spoofing are turned on.
// wildcardAddr := tcpip.ProtocolAddress{
// Protocol: ipv4.ProtocolNumber,
// AddressWithPrefix: tcpip.AddressWithPrefix{
// Address: tcpip.AddrFrom4([4]byte{10, 0, 0, 1}),
// PrefixLen: 8,
// },
// }
// if tcpipErr = dev.stack.AddProtocolAddress(2, wildcardAddr, stack.AddressProperties{
// PEB: stack.CanBePrimaryEndpoint,
// }); tcpipErr != nil {
// return nil, nil, fmt.Errorf("Add wildcard proxy address: %v", tcpipErr)
// }
// // Keep the real service IP (10.20.20.1/24) so existing clients that target the
// // gateway explicitly still resolve as before.
// proxyAddr := netip.MustParseAddr("10.20.20.1")
// protoAddr := tcpip.ProtocolAddress{
// Protocol: ipv4.ProtocolNumber,
// AddressWithPrefix: tcpip.AddressWithPrefix{
// Address: tcpip.AddrFromSlice(proxyAddr.AsSlice()),
// PrefixLen: 24,
// },
// }
// if tcpipErr = dev.stack.AddProtocolAddress(2, protoAddr, stack.AddressProperties{}); tcpipErr != nil {
// return nil, nil, fmt.Errorf("Add proxy service address: %v", tcpipErr)
// }
// proxySubnet := netip.MustParsePrefix("10.20.20.0/24")
// proxyTcpipSubnet, err := tcpip.NewSubnet(
// tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()),
// tcpip.MaskFromBytes(net.CIDRMask(24, 32)),
// )
// if err != nil {
// return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err)
// }
dev.proxyStack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
}
// print the stack routes table and interfaces for debugging
logger.Info("Stack configuration:")
// // Print NICs
// nics := dev.stack.NICInfo()
// for nicID, nicInfo := range nics {
// logger.Info("NIC %d: %s (MTU: %d)", nicID, nicInfo.Name, nicInfo.MTU)
// for _, addr := range nicInfo.ProtocolAddresses {
// logger.Info(" Address: %s", addr.AddressWithPrefix)
// }
// }
// // Print routing table
// routes := dev.stack.GetRouteTable()
// logger.Info("Routing table (%d routes):", len(routes))
// for i, route := range routes {
// logger.Info(" Route %d: %s -> NIC %d", i, route.Destination, route.NIC)
// }
dev.events <- tun.EventUp
return dev, (*Net)(dev), nil
}
@@ -287,32 +184,20 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
continue
}
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
// Try to handle packet via proxy handler first
if tun.proxyHandler != nil && tun.proxyHandler.HandleIncomingPacket(packet) {
// Packet was handled by proxy
continue
}
// Determine which NIC to inject the packet into based on destination IP
targetEp := tun.ep // Default to NIC 1
// Default handling: inject into main stack
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
switch packet[0] >> 4 {
case 4:
// // Parse IPv4 header to check destination
if len(packet) >= header.IPv4MinimumSize {
ipv4Header := header.IPv4(packet)
dstIP := ipv4Header.DestinationAddress()
// Check if destination is in the proxy range (10.20.20.0/24)
// If so, inject into proxyEp (NIC 2) which has promiscuous mode
if tun.proxyEp != nil {
dstBytes := dstIP.As4()
// Check for 10.20.20.x
if dstBytes[0] == 10 && dstBytes[1] == 20 && dstBytes[2] == 20 {
targetEp = tun.proxyEp
}
}
}
targetEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6:
// For IPv6, always use NIC 1 for now
targetEp.InjectInbound(header.IPv6ProtocolNumber, pkb)
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
default:
return 0, syscall.EAFNOSUPPORT
}
@@ -320,86 +205,6 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
return len(buf), nil
}
// logPacketDetails parses and logs packet information
func logPacketDetails(pkt *stack.PacketBuffer, nicID int) {
netProto := pkt.NetworkProtocolNumber
var srcIP, dstIP string
var protocol string
var srcPort, dstPort uint16
// Parse network layer
switch netProto {
case header.IPv4ProtocolNumber:
if pkt.NetworkHeader().View().Size() >= header.IPv4MinimumSize {
ipv4 := header.IPv4(pkt.NetworkHeader().Slice())
srcIP = ipv4.SourceAddress().String()
dstIP = ipv4.DestinationAddress().String()
// Parse transport layer
switch ipv4.Protocol() {
case uint8(header.TCPProtocolNumber):
protocol = "TCP"
if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize {
tcp := header.TCP(pkt.TransportHeader().Slice())
srcPort = tcp.SourcePort()
dstPort = tcp.DestinationPort()
}
case uint8(header.UDPProtocolNumber):
protocol = "UDP"
if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize {
udp := header.UDP(pkt.TransportHeader().Slice())
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
}
case uint8(header.ICMPv4ProtocolNumber):
protocol = "ICMPv4"
default:
protocol = fmt.Sprintf("Proto-%d", ipv4.Protocol())
}
}
case header.IPv6ProtocolNumber:
if pkt.NetworkHeader().View().Size() >= header.IPv6MinimumSize {
ipv6 := header.IPv6(pkt.NetworkHeader().Slice())
srcIP = ipv6.SourceAddress().String()
dstIP = ipv6.DestinationAddress().String()
// Parse transport layer
switch ipv6.TransportProtocol() {
case header.TCPProtocolNumber:
protocol = "TCP"
if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize {
tcp := header.TCP(pkt.TransportHeader().Slice())
srcPort = tcp.SourcePort()
dstPort = tcp.DestinationPort()
}
case header.UDPProtocolNumber:
protocol = "UDP"
if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize {
udp := header.UDP(pkt.TransportHeader().Slice())
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
}
case header.ICMPv6ProtocolNumber:
protocol = "ICMPv6"
default:
protocol = fmt.Sprintf("Proto-%d", ipv6.TransportProtocol())
}
}
default:
protocol = fmt.Sprintf("Unknown-NetProto-%d", netProto)
}
packetSize := pkt.Size()
if srcPort > 0 && dstPort > 0 {
logger.Info("NIC %d packet: %s %s:%d -> %s:%d (size: %d bytes)",
nicID, protocol, srcIP, srcPort, dstIP, dstPort, packetSize)
} else {
logger.Info("NIC %d packet: %s %s -> %s (size: %d bytes)",
nicID, protocol, srcIP, dstIP, packetSize)
}
}
func (tun *netTun) WriteNotify() {
// Handle notifications from main endpoint (NIC 1)
pkt := tun.ep.Read()
@@ -410,13 +215,11 @@ func (tun *netTun) WriteNotify() {
return
}
// Handle notifications from proxy endpoint (NIC 2) if it exists
// Handle notifications from proxy handler if it exists
// These are response packets from the proxied connections that need to go back to WireGuard
if tun.proxyEp != nil {
pkt = tun.proxyEp.Read()
if pkt != nil {
view := pkt.ToView()
pkt.DecRef()
if tun.proxyHandler != nil {
view := tun.proxyHandler.ReadOutgoingPacket()
if view != nil {
tun.incomingPacket <- view
return
}
@@ -426,17 +229,15 @@ func (tun *netTun) WriteNotify() {
func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1)
// Clean up proxy NIC if it exists
if tun.proxyEp != nil {
tun.stack.RemoveNIC(2)
tun.proxyEp.RemoveNotify(tun.proxyNotifyHandle)
tun.proxyEp.Close()
}
tun.stack.Close()
tun.ep.RemoveNotify(tun.notifyHandle)
tun.ep.Close()
// Clean up proxy handler if it exists
if tun.proxyHandler != nil {
tun.proxyHandler.Close()
}
if tun.events != nil {
close(tun.events)
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/netstack2"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/util"
"github.com/fosrl/newt/websocket"
"golang.zx2c4.com/wireguard/device"
@@ -88,34 +87,12 @@ type WireGuardService struct {
onNetstackClose func()
othertnet *netstack.Net
// Proxy manager for tunnel
proxyManager *proxy.ProxyManager
TunnelIP string
TunnelIP string
// Shared bind and holepunch manager
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
}
// GetProxyManager returns the proxy manager for this WireGuardService
func (s *WireGuardService) GetProxyManager() *proxy.ProxyManager {
return s.proxyManager
}
// AddProxyTarget adds a target to the proxy manager
func (s *WireGuardService) AddProxyTarget(proto, listenIP string, port int, targetAddr string) error {
if s.proxyManager == nil {
return fmt.Errorf("proxy manager not initialized")
}
return s.proxyManager.AddTarget(proto, listenIP, port, targetAddr)
}
// RemoveProxyTarget removes a target from the proxy manager
func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) error {
if s.proxyManager == nil {
return fmt.Errorf("proxy manager not initialized")
}
return s.proxyManager.RemoveTarget(proto, listenIP, port)
}
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) {
var key wgtypes.Key
var err error
@@ -189,7 +166,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
lastReadings: make(map[string]PeerReading),
Port: port,
dns: dnsAddrs,
proxyManager: proxy.NewProxyManagerWithoutTNet(),
sharedBind: sharedBind,
}
@@ -202,10 +178,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
wsClient.RegisterHandler("newt/wg/tcp/add", service.addTcpTarget)
wsClient.RegisterHandler("newt/wg/udp/add", service.addUdpTarget)
wsClient.RegisterHandler("newt/wg/udp/remove", service.removeUdpTarget)
wsClient.RegisterHandler("newt/wg/tcp/remove", service.removeTcpTarget)
return service, nil
}
@@ -218,86 +190,6 @@ func (s *WireGuardService) ReportRTT(seconds float64) {
telemetry.ObserveTunnelLatency(context.Background(), s.serverPubKey, "wireguard", seconds)
}
func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
logger.Debug("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if s.TunnelIP == "" || s.proxyManager == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
}
}
func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if s.TunnelIP == "" || s.proxyManager == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
}
}
func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if s.TunnelIP == "" || s.proxyManager == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
}
}
func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if s.TunnelIP == "" || s.proxyManager == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
}
}
func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) {
s.othertnet = tnet
}
@@ -435,18 +327,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
// add the targets if there are any
if len(config.Targets.TCP) > 0 {
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
}
if len(config.Targets.UDP) > 0 {
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
}
// Create ProxyManager for this tunnel
s.proxyManager.Start()
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
@@ -484,7 +364,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.mu.Unlock()
return fmt.Errorf("failed to create TUN device: %v", err)
}
// s.proxyManager.SetTNet(s.tnet)
s.TunnelIP = tunnelIP.String()
// Create WireGuard device using the shared bind
@@ -921,169 +801,6 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return nil
}
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
var replace = false
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
}
} else {
replace = true // We successfully removed an existing target
}
// Add the new target
pm.AddTarget(proto, tunnelIP, port, processedTarget)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
}
}
}
if replace {
// If we replaced any targets, we need to hot swap the netstack
if err := s.ReplaceNetstack(); err != nil {
logger.Error("Failed to replace netstack after updating targets: %v", err)
return err
}
logger.Info("Netstack replaced successfully after updating targets")
} else {
logger.Info("No targets updated, no netstack replacement needed")
}
return nil
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
// Add this method to WireGuardService
func (s *WireGuardService) ReplaceNetstack() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.device == nil || s.tun == nil {
return fmt.Errorf("WireGuard device not initialized")
}
// Parse the current tunnel IP from the existing config
parts := strings.Split(s.config.IpAddress, "/")
if len(parts) != 2 {
return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress)
}
tunnelIP := netip.MustParseAddr(parts[0])
// Stop the proxy manager temporarily
s.proxyManager.Stop()
// Create new TUN device and netstack with new DNS
newTun, newTnet, err := netstack2.CreateNetTUN(
[]netip.Addr{tunnelIP},
s.dns,
s.mtu)
if err != nil {
// Restart proxy manager with old tnet on failure
s.proxyManager.Start()
return fmt.Errorf("failed to create new TUN device: %v", err)
}
// Get current device config before closing
currentConfig, err := s.device.IpcGet()
if err != nil {
newTun.Close()
s.proxyManager.Start()
return fmt.Errorf("failed to get current device config: %v", err)
}
// Filter out read-only fields from the config
filteredConfig := s.filterReadOnlyFields(currentConfig)
// if onNetstackClose callback is set, call it
if s.onNetstackClose != nil {
s.onNetstackClose()
}
// Close old device (this closes the old TUN device)
s.device.Close()
// Update references
s.tun = newTun
s.tnet = newTnet
// Create new WireGuard device with same shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent,
"wireguard: ",
))
// Restore the configuration (without read-only fields)
err = s.device.IpcSet(filteredConfig)
if err != nil {
return fmt.Errorf("failed to restore WireGuard configuration: %v", err)
}
// Bring up the device
err = s.device.Up()
if err != nil {
return fmt.Errorf("failed to bring up new WireGuard device: %v", err)
}
// Update proxy manager with new tnet and restart
// s.proxyManager.SetTNet(s.tnet)
s.proxyManager.Start()
s.proxyManager.PrintTargets()
// Call the netstack ready callback if set
if s.onNetstackReady != nil {
go s.onNetstackReady(s.tnet)
}
return nil
}
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
func (s *WireGuardService) filterReadOnlyFields(config string) string {
lines := strings.Split(config, "\n")