mirror of
https://github.com/fosrl/newt.git
synced 2026-03-02 00:36:42 +00:00
Remove proxy manager and break out subnet proxy
This commit is contained in:
321
netstack2/proxy.go
Normal file
321
netstack2/proxy.go
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
package netstack2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PortRange represents an allowed range of ports (inclusive)
|
||||||
|
type PortRange struct {
|
||||||
|
Min uint16
|
||||||
|
Max uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubnetRule represents a subnet with optional port restrictions
|
||||||
|
type SubnetRule struct {
|
||||||
|
Prefix netip.Prefix
|
||||||
|
PortRanges []PortRange // empty slice means all ports allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubnetLookup provides fast IP subnet and port matching
|
||||||
|
type SubnetLookup struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
rules []SubnetRule
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSubnetLookup creates a new subnet lookup table
|
||||||
|
func NewSubnetLookup() *SubnetLookup {
|
||||||
|
return &SubnetLookup{
|
||||||
|
rules: make([]SubnetRule, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSubnet adds a subnet to the lookup table with optional port restrictions
|
||||||
|
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||||
|
func (sl *SubnetLookup) AddSubnet(prefix netip.Prefix, portRanges []PortRange) {
|
||||||
|
sl.mu.Lock()
|
||||||
|
defer sl.mu.Unlock()
|
||||||
|
|
||||||
|
sl.rules = append(sl.rules, SubnetRule{
|
||||||
|
Prefix: prefix,
|
||||||
|
PortRanges: portRanges,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveSubnet removes a subnet from the lookup table
|
||||||
|
func (sl *SubnetLookup) RemoveSubnet(prefix netip.Prefix) {
|
||||||
|
sl.mu.Lock()
|
||||||
|
defer sl.mu.Unlock()
|
||||||
|
|
||||||
|
for i, rule := range sl.rules {
|
||||||
|
if rule.Prefix == prefix {
|
||||||
|
sl.rules = append(sl.rules[:i], sl.rules[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match checks if an IP and port match any subnet rule
|
||||||
|
// Returns true if the IP is in a matching subnet AND the port is in an allowed range
|
||||||
|
func (sl *SubnetLookup) Match(ip netip.Addr, port uint16) bool {
|
||||||
|
sl.mu.RLock()
|
||||||
|
defer sl.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, rule := range sl.rules {
|
||||||
|
if rule.Prefix.Contains(ip) {
|
||||||
|
// If no port ranges specified, all ports are allowed
|
||||||
|
if len(rule.PortRanges) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if port is in any of the allowed ranges
|
||||||
|
for _, pr := range rule.PortRanges {
|
||||||
|
if port >= pr.Min && port <= pr.Max {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
||||||
|
type ProxyHandler struct {
|
||||||
|
proxyStack *stack.Stack
|
||||||
|
proxyEp *channel.Endpoint
|
||||||
|
proxyNotifyHandle *channel.NotificationHandle
|
||||||
|
tcpHandler *TCPHandler
|
||||||
|
udpHandler *UDPHandler
|
||||||
|
subnetLookup *SubnetLookup
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyHandlerOptions configures the proxy handler
|
||||||
|
type ProxyHandlerOptions struct {
|
||||||
|
EnableTCP bool
|
||||||
|
EnableUDP bool
|
||||||
|
MTU int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxyHandler creates a new proxy handler for promiscuous mode
|
||||||
|
func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||||
|
if !options.EnableTCP && !options.EnableUDP {
|
||||||
|
return nil, nil // No proxy needed
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := &ProxyHandler{
|
||||||
|
enabled: true,
|
||||||
|
subnetLookup: NewSubnetLookup(),
|
||||||
|
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||||
|
proxyStack: stack.New(stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
|
ipv4.NewProtocol,
|
||||||
|
ipv6.NewProtocol,
|
||||||
|
},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
|
tcp.NewProtocol,
|
||||||
|
udp.NewProtocol,
|
||||||
|
icmp.NewProtocol4,
|
||||||
|
icmp.NewProtocol6,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize TCP handler if enabled
|
||||||
|
if options.EnableTCP {
|
||||||
|
handler.tcpHandler = NewTCPHandler(handler.proxyStack)
|
||||||
|
if err := handler.tcpHandler.InstallTCPHandler(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to install TCP handler: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize UDP handler if enabled
|
||||||
|
if options.EnableUDP {
|
||||||
|
handler.udpHandler = NewUDPHandler(handler.proxyStack)
|
||||||
|
if err := handler.udpHandler.InstallUDPHandler(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to install UDP handler: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 1: Add a subnet with no port restrictions (all ports allowed)
|
||||||
|
// This accepts all traffic to 10.20.20.0/24
|
||||||
|
subnet1 := netip.MustParsePrefix("10.20.20.0/24")
|
||||||
|
handler.AddSubnetRule(subnet1, nil)
|
||||||
|
|
||||||
|
// Example 2: Add a subnet with specific port ranges
|
||||||
|
// This accepts traffic to 192.168.1.0/24 only on ports 80, 443, and 8000-9000
|
||||||
|
subnet2 := netip.MustParsePrefix("10.20.21.21/32")
|
||||||
|
handler.AddSubnetRule(subnet2, []PortRange{
|
||||||
|
{Min: 12000, Max: 12001},
|
||||||
|
{Min: 8000, Max: 8000},
|
||||||
|
})
|
||||||
|
|
||||||
|
return handler, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
|
||||||
|
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||||
|
func (p *ProxyHandler) AddSubnetRule(prefix netip.Prefix, portRanges []PortRange) {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.subnetLookup.AddSubnet(prefix, portRanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveSubnetRule removes a subnet from the proxy handler
|
||||||
|
func (p *ProxyHandler) RemoveSubnetRule(prefix netip.Prefix) {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.subnetLookup.RemoveSubnet(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize sets up the promiscuous NIC with the netTun's notification system
|
||||||
|
func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add notification handler
|
||||||
|
p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable)
|
||||||
|
|
||||||
|
// Create NIC with promiscuous mode
|
||||||
|
tcpipErr := p.proxyStack.CreateNICWithOptions(1, p.proxyEp, stack.NICOptions{
|
||||||
|
Disabled: false,
|
||||||
|
QDisc: nil,
|
||||||
|
})
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return fmt.Errorf("CreateNIC (proxy): %v", tcpipErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable promiscuous mode - accepts packets for any destination IP
|
||||||
|
if tcpipErr := p.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil {
|
||||||
|
return fmt.Errorf("SetPromiscuousMode: %v", tcpipErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable spoofing - allows sending packets from any source IP
|
||||||
|
if tcpipErr := p.proxyStack.SetSpoofing(1, true); tcpipErr != nil {
|
||||||
|
return fmt.Errorf("SetSpoofing: %v", tcpipErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add default route
|
||||||
|
p.proxyStack.AddRoute(tcpip.Route{
|
||||||
|
Destination: header.IPv4EmptySubnet,
|
||||||
|
NIC: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleIncomingPacket processes incoming packets and determines if they should
|
||||||
|
// be injected into the proxy stack
|
||||||
|
func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check minimum packet size
|
||||||
|
if len(packet) < header.IPv4MinimumSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle IPv4 for now
|
||||||
|
if packet[0]>>4 != 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse IPv4 header
|
||||||
|
ipv4Header := header.IPv4(packet)
|
||||||
|
dstIP := ipv4Header.DestinationAddress()
|
||||||
|
|
||||||
|
// Convert gvisor tcpip.Address to netip.Addr
|
||||||
|
dstBytes := dstIP.As4()
|
||||||
|
addr := netip.AddrFrom4(dstBytes)
|
||||||
|
|
||||||
|
// Parse transport layer to get destination port
|
||||||
|
var dstPort uint16
|
||||||
|
protocol := ipv4Header.TransportProtocol()
|
||||||
|
headerLen := int(ipv4Header.HeaderLength())
|
||||||
|
|
||||||
|
// Extract port based on protocol
|
||||||
|
switch protocol {
|
||||||
|
case header.TCPProtocolNumber:
|
||||||
|
if len(packet) < headerLen+header.TCPMinimumSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tcpHeader := header.TCP(packet[headerLen:])
|
||||||
|
dstPort = tcpHeader.DestinationPort()
|
||||||
|
case header.UDPProtocolNumber:
|
||||||
|
if len(packet) < headerLen+header.UDPMinimumSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
udpHeader := header.UDP(packet[headerLen:])
|
||||||
|
dstPort = udpHeader.DestinationPort()
|
||||||
|
default:
|
||||||
|
// For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions)
|
||||||
|
dstPort = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the destination IP and port match any subnet rule
|
||||||
|
if p.subnetLookup.Match(addr, dstPort) {
|
||||||
|
// Inject into proxy stack
|
||||||
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(packet),
|
||||||
|
})
|
||||||
|
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadOutgoingPacket reads packets from the proxy stack that need to be
|
||||||
|
// sent back through the tunnel
|
||||||
|
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := p.proxyEp.Read()
|
||||||
|
if pkt != nil {
|
||||||
|
view := pkt.ToView()
|
||||||
|
pkt.DecRef()
|
||||||
|
return view
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up the proxy handler resources
|
||||||
|
func (p *ProxyHandler) Close() error {
|
||||||
|
if p == nil || !p.enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.proxyStack != nil {
|
||||||
|
p.proxyStack.RemoveNIC(1)
|
||||||
|
p.proxyStack.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.proxyEp != nil {
|
||||||
|
if p.proxyNotifyHandle != nil {
|
||||||
|
p.proxyEp.RemoveNotify(p.proxyNotifyHandle)
|
||||||
|
}
|
||||||
|
p.proxyEp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
297
netstack2/tun.go
297
netstack2/tun.go
@@ -22,7 +22,6 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
@@ -41,19 +40,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type netTun struct {
|
type netTun struct {
|
||||||
ep *channel.Endpoint
|
ep *channel.Endpoint
|
||||||
proxyEp *channel.Endpoint // Separate endpoint for promiscuous mode
|
stack *stack.Stack
|
||||||
stack *stack.Stack
|
events chan tun.Event
|
||||||
proxyStack *stack.Stack // Separate stack for proxy endpoint
|
notifyHandle *channel.NotificationHandle
|
||||||
events chan tun.Event
|
incomingPacket chan *buffer.View
|
||||||
notifyHandle *channel.NotificationHandle
|
mtu int
|
||||||
proxyNotifyHandle *channel.NotificationHandle // Notify handle for proxy endpoint
|
dnsServers []netip.Addr
|
||||||
incomingPacket chan *buffer.View
|
hasV4, hasV6 bool
|
||||||
mtu int
|
proxyHandler *ProxyHandler // Handles promiscuous mode packet processing
|
||||||
dnsServers []netip.Addr
|
|
||||||
hasV4, hasV6 bool
|
|
||||||
tcpHandler *TCPHandler
|
|
||||||
udpHandler *UDPHandler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Net netTun
|
type Net netTun
|
||||||
@@ -80,27 +75,27 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
|
|||||||
HandleLocal: true,
|
HandleLocal: true,
|
||||||
}
|
}
|
||||||
dev := &netTun{
|
dev := &netTun{
|
||||||
ep: channel.New(1024, uint32(mtu), ""),
|
ep: channel.New(1024, uint32(mtu), ""),
|
||||||
proxyEp: channel.New(1024, uint32(mtu), ""),
|
stack: stack.New(stackOpts),
|
||||||
stack: stack.New(stackOpts),
|
|
||||||
proxyStack: stack.New(stack.Options{
|
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
|
||||||
ipv4.NewProtocol,
|
|
||||||
ipv6.NewProtocol,
|
|
||||||
},
|
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
|
||||||
tcp.NewProtocol,
|
|
||||||
udp.NewProtocol,
|
|
||||||
icmp.NewProtocol4,
|
|
||||||
icmp.NewProtocol6,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
events: make(chan tun.Event, 10),
|
events: make(chan tun.Event, 10),
|
||||||
incomingPacket: make(chan *buffer.View),
|
incomingPacket: make(chan *buffer.View),
|
||||||
dnsServers: dnsServers,
|
dnsServers: dnsServers,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize proxy handler if TCP or UDP proxying is enabled
|
||||||
|
if options.EnableTCPProxy || options.EnableUDPProxy {
|
||||||
|
var err error
|
||||||
|
dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{
|
||||||
|
EnableTCP: options.EnableTCPProxy,
|
||||||
|
EnableUDP: options.EnableUDPProxy,
|
||||||
|
MTU: mtu,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default
|
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default
|
||||||
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
||||||
if tcpipErr != nil {
|
if tcpipErr != nil {
|
||||||
@@ -113,6 +108,13 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
|
|||||||
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize proxy handler after main stack is set up
|
||||||
|
if dev.proxyHandler != nil {
|
||||||
|
if err := dev.proxyHandler.Initialize(dev); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := dev.stack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
|
if err := dev.stack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
|
||||||
return nil, nil, fmt.Errorf("set ipv4 forwarding: %s", err)
|
return nil, nil, fmt.Errorf("set ipv4 forwarding: %s", err)
|
||||||
}
|
}
|
||||||
@@ -145,111 +147,6 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
|
|||||||
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add specific route for proxy network (10.20.20.0/24) to NIC 2
|
|
||||||
if options.EnableTCPProxy || options.EnableUDPProxy {
|
|
||||||
|
|
||||||
if options.EnableTCPProxy {
|
|
||||||
dev.tcpHandler = NewTCPHandler(dev.proxyStack)
|
|
||||||
if err := dev.tcpHandler.InstallTCPHandler(); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("failed to install TCP handler: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.EnableUDPProxy {
|
|
||||||
dev.udpHandler = NewUDPHandler(dev.proxyStack)
|
|
||||||
if err := dev.udpHandler.InstallUDPHandler(); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("failed to install UDP handler: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
dev.proxyNotifyHandle = dev.proxyEp.AddNotify(dev)
|
|
||||||
tcpipErr = dev.proxyStack.CreateNICWithOptions(1, dev.proxyEp, stack.NICOptions{
|
|
||||||
Disabled: false,
|
|
||||||
// If no queueing discipline was specified
|
|
||||||
// provide a stub implementation that just
|
|
||||||
// delegates to the lower link endpoint.
|
|
||||||
QDisc: nil,
|
|
||||||
})
|
|
||||||
if tcpipErr != nil {
|
|
||||||
return nil, nil, fmt.Errorf("CreateNIC 2 (proxy): %v", tcpipErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable promiscuous mode ONLY on NIC 2
|
|
||||||
// This allows the NIC to accept packets destined for any IP address
|
|
||||||
if tcpipErr := dev.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil {
|
|
||||||
return nil, nil, fmt.Errorf("SetPromiscuousMode on NIC 2: %v", tcpipErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable spoofing ONLY on NIC 2
|
|
||||||
// This allows the stack to send packets from any address, not just owned addresses
|
|
||||||
if tcpipErr := dev.proxyStack.SetSpoofing(1, true); tcpipErr != nil {
|
|
||||||
return nil, nil, fmt.Errorf("SetSpoofing on NIC 2: %v", tcpipErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// // Add a wildcard IPv4 address covering the 10.0.0.0/8 space so the stack can
|
|
||||||
// // synthesize temporary endpoints for any 10.x.y.z destination. This mimics
|
|
||||||
// // the tun2socks behaviour and is required once promiscuous+spoofing are turned on.
|
|
||||||
// wildcardAddr := tcpip.ProtocolAddress{
|
|
||||||
// Protocol: ipv4.ProtocolNumber,
|
|
||||||
// AddressWithPrefix: tcpip.AddressWithPrefix{
|
|
||||||
// Address: tcpip.AddrFrom4([4]byte{10, 0, 0, 1}),
|
|
||||||
// PrefixLen: 8,
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// if tcpipErr = dev.stack.AddProtocolAddress(2, wildcardAddr, stack.AddressProperties{
|
|
||||||
// PEB: stack.CanBePrimaryEndpoint,
|
|
||||||
// }); tcpipErr != nil {
|
|
||||||
// return nil, nil, fmt.Errorf("Add wildcard proxy address: %v", tcpipErr)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Keep the real service IP (10.20.20.1/24) so existing clients that target the
|
|
||||||
// // gateway explicitly still resolve as before.
|
|
||||||
// proxyAddr := netip.MustParseAddr("10.20.20.1")
|
|
||||||
// protoAddr := tcpip.ProtocolAddress{
|
|
||||||
// Protocol: ipv4.ProtocolNumber,
|
|
||||||
// AddressWithPrefix: tcpip.AddressWithPrefix{
|
|
||||||
// Address: tcpip.AddrFromSlice(proxyAddr.AsSlice()),
|
|
||||||
// PrefixLen: 24,
|
|
||||||
// },
|
|
||||||
// }
|
|
||||||
// if tcpipErr = dev.stack.AddProtocolAddress(2, protoAddr, stack.AddressProperties{}); tcpipErr != nil {
|
|
||||||
// return nil, nil, fmt.Errorf("Add proxy service address: %v", tcpipErr)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// proxySubnet := netip.MustParsePrefix("10.20.20.0/24")
|
|
||||||
// proxyTcpipSubnet, err := tcpip.NewSubnet(
|
|
||||||
// tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()),
|
|
||||||
// tcpip.MaskFromBytes(net.CIDRMask(24, 32)),
|
|
||||||
// )
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
dev.proxyStack.AddRoute(tcpip.Route{
|
|
||||||
Destination: header.IPv4EmptySubnet,
|
|
||||||
NIC: 1,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// print the stack routes table and interfaces for debugging
|
|
||||||
logger.Info("Stack configuration:")
|
|
||||||
|
|
||||||
// // Print NICs
|
|
||||||
// nics := dev.stack.NICInfo()
|
|
||||||
// for nicID, nicInfo := range nics {
|
|
||||||
// logger.Info("NIC %d: %s (MTU: %d)", nicID, nicInfo.Name, nicInfo.MTU)
|
|
||||||
// for _, addr := range nicInfo.ProtocolAddresses {
|
|
||||||
// logger.Info(" Address: %s", addr.AddressWithPrefix)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Print routing table
|
|
||||||
// routes := dev.stack.GetRouteTable()
|
|
||||||
// logger.Info("Routing table (%d routes):", len(routes))
|
|
||||||
// for i, route := range routes {
|
|
||||||
// logger.Info(" Route %d: %s -> NIC %d", i, route.Destination, route.NIC)
|
|
||||||
// }
|
|
||||||
|
|
||||||
dev.events <- tun.EventUp
|
dev.events <- tun.EventUp
|
||||||
return dev, (*Net)(dev), nil
|
return dev, (*Net)(dev), nil
|
||||||
}
|
}
|
||||||
@@ -287,32 +184,20 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
// Try to handle packet via proxy handler first
|
||||||
|
if tun.proxyHandler != nil && tun.proxyHandler.HandleIncomingPacket(packet) {
|
||||||
|
// Packet was handled by proxy
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Determine which NIC to inject the packet into based on destination IP
|
// Default handling: inject into main stack
|
||||||
targetEp := tun.ep // Default to NIC 1
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
||||||
|
|
||||||
switch packet[0] >> 4 {
|
switch packet[0] >> 4 {
|
||||||
case 4:
|
case 4:
|
||||||
// // Parse IPv4 header to check destination
|
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||||
if len(packet) >= header.IPv4MinimumSize {
|
|
||||||
ipv4Header := header.IPv4(packet)
|
|
||||||
dstIP := ipv4Header.DestinationAddress()
|
|
||||||
|
|
||||||
// Check if destination is in the proxy range (10.20.20.0/24)
|
|
||||||
// If so, inject into proxyEp (NIC 2) which has promiscuous mode
|
|
||||||
if tun.proxyEp != nil {
|
|
||||||
dstBytes := dstIP.As4()
|
|
||||||
// Check for 10.20.20.x
|
|
||||||
if dstBytes[0] == 10 && dstBytes[1] == 20 && dstBytes[2] == 20 {
|
|
||||||
targetEp = tun.proxyEp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
targetEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
|
||||||
case 6:
|
case 6:
|
||||||
// For IPv6, always use NIC 1 for now
|
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||||
targetEp.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
|
||||||
default:
|
default:
|
||||||
return 0, syscall.EAFNOSUPPORT
|
return 0, syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
@@ -320,86 +205,6 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
|||||||
return len(buf), nil
|
return len(buf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// logPacketDetails parses and logs packet information
|
|
||||||
func logPacketDetails(pkt *stack.PacketBuffer, nicID int) {
|
|
||||||
netProto := pkt.NetworkProtocolNumber
|
|
||||||
var srcIP, dstIP string
|
|
||||||
var protocol string
|
|
||||||
var srcPort, dstPort uint16
|
|
||||||
|
|
||||||
// Parse network layer
|
|
||||||
switch netProto {
|
|
||||||
case header.IPv4ProtocolNumber:
|
|
||||||
if pkt.NetworkHeader().View().Size() >= header.IPv4MinimumSize {
|
|
||||||
ipv4 := header.IPv4(pkt.NetworkHeader().Slice())
|
|
||||||
srcIP = ipv4.SourceAddress().String()
|
|
||||||
dstIP = ipv4.DestinationAddress().String()
|
|
||||||
|
|
||||||
// Parse transport layer
|
|
||||||
switch ipv4.Protocol() {
|
|
||||||
case uint8(header.TCPProtocolNumber):
|
|
||||||
protocol = "TCP"
|
|
||||||
if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize {
|
|
||||||
tcp := header.TCP(pkt.TransportHeader().Slice())
|
|
||||||
srcPort = tcp.SourcePort()
|
|
||||||
dstPort = tcp.DestinationPort()
|
|
||||||
}
|
|
||||||
case uint8(header.UDPProtocolNumber):
|
|
||||||
protocol = "UDP"
|
|
||||||
if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize {
|
|
||||||
udp := header.UDP(pkt.TransportHeader().Slice())
|
|
||||||
srcPort = udp.SourcePort()
|
|
||||||
dstPort = udp.DestinationPort()
|
|
||||||
}
|
|
||||||
case uint8(header.ICMPv4ProtocolNumber):
|
|
||||||
protocol = "ICMPv4"
|
|
||||||
default:
|
|
||||||
protocol = fmt.Sprintf("Proto-%d", ipv4.Protocol())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case header.IPv6ProtocolNumber:
|
|
||||||
if pkt.NetworkHeader().View().Size() >= header.IPv6MinimumSize {
|
|
||||||
ipv6 := header.IPv6(pkt.NetworkHeader().Slice())
|
|
||||||
srcIP = ipv6.SourceAddress().String()
|
|
||||||
dstIP = ipv6.DestinationAddress().String()
|
|
||||||
|
|
||||||
// Parse transport layer
|
|
||||||
switch ipv6.TransportProtocol() {
|
|
||||||
case header.TCPProtocolNumber:
|
|
||||||
protocol = "TCP"
|
|
||||||
if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize {
|
|
||||||
tcp := header.TCP(pkt.TransportHeader().Slice())
|
|
||||||
srcPort = tcp.SourcePort()
|
|
||||||
dstPort = tcp.DestinationPort()
|
|
||||||
}
|
|
||||||
case header.UDPProtocolNumber:
|
|
||||||
protocol = "UDP"
|
|
||||||
if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize {
|
|
||||||
udp := header.UDP(pkt.TransportHeader().Slice())
|
|
||||||
srcPort = udp.SourcePort()
|
|
||||||
dstPort = udp.DestinationPort()
|
|
||||||
}
|
|
||||||
case header.ICMPv6ProtocolNumber:
|
|
||||||
protocol = "ICMPv6"
|
|
||||||
default:
|
|
||||||
protocol = fmt.Sprintf("Proto-%d", ipv6.TransportProtocol())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
protocol = fmt.Sprintf("Unknown-NetProto-%d", netProto)
|
|
||||||
}
|
|
||||||
|
|
||||||
packetSize := pkt.Size()
|
|
||||||
|
|
||||||
if srcPort > 0 && dstPort > 0 {
|
|
||||||
logger.Info("NIC %d packet: %s %s:%d -> %s:%d (size: %d bytes)",
|
|
||||||
nicID, protocol, srcIP, srcPort, dstIP, dstPort, packetSize)
|
|
||||||
} else {
|
|
||||||
logger.Info("NIC %d packet: %s %s -> %s (size: %d bytes)",
|
|
||||||
nicID, protocol, srcIP, dstIP, packetSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *netTun) WriteNotify() {
|
func (tun *netTun) WriteNotify() {
|
||||||
// Handle notifications from main endpoint (NIC 1)
|
// Handle notifications from main endpoint (NIC 1)
|
||||||
pkt := tun.ep.Read()
|
pkt := tun.ep.Read()
|
||||||
@@ -410,13 +215,11 @@ func (tun *netTun) WriteNotify() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle notifications from proxy endpoint (NIC 2) if it exists
|
// Handle notifications from proxy handler if it exists
|
||||||
// These are response packets from the proxied connections that need to go back to WireGuard
|
// These are response packets from the proxied connections that need to go back to WireGuard
|
||||||
if tun.proxyEp != nil {
|
if tun.proxyHandler != nil {
|
||||||
pkt = tun.proxyEp.Read()
|
view := tun.proxyHandler.ReadOutgoingPacket()
|
||||||
if pkt != nil {
|
if view != nil {
|
||||||
view := pkt.ToView()
|
|
||||||
pkt.DecRef()
|
|
||||||
tun.incomingPacket <- view
|
tun.incomingPacket <- view
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -426,17 +229,15 @@ func (tun *netTun) WriteNotify() {
|
|||||||
func (tun *netTun) Close() error {
|
func (tun *netTun) Close() error {
|
||||||
tun.stack.RemoveNIC(1)
|
tun.stack.RemoveNIC(1)
|
||||||
|
|
||||||
// Clean up proxy NIC if it exists
|
|
||||||
if tun.proxyEp != nil {
|
|
||||||
tun.stack.RemoveNIC(2)
|
|
||||||
tun.proxyEp.RemoveNotify(tun.proxyNotifyHandle)
|
|
||||||
tun.proxyEp.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.stack.Close()
|
tun.stack.Close()
|
||||||
tun.ep.RemoveNotify(tun.notifyHandle)
|
tun.ep.RemoveNotify(tun.notifyHandle)
|
||||||
tun.ep.Close()
|
tun.ep.Close()
|
||||||
|
|
||||||
|
// Clean up proxy handler if it exists
|
||||||
|
if tun.proxyHandler != nil {
|
||||||
|
tun.proxyHandler.Close()
|
||||||
|
}
|
||||||
|
|
||||||
if tun.events != nil {
|
if tun.events != nil {
|
||||||
close(tun.events)
|
close(tun.events)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/fosrl/newt/holepunch"
|
"github.com/fosrl/newt/holepunch"
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/netstack2"
|
"github.com/fosrl/newt/netstack2"
|
||||||
"github.com/fosrl/newt/proxy"
|
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/newt/websocket"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -88,34 +87,12 @@ type WireGuardService struct {
|
|||||||
onNetstackClose func()
|
onNetstackClose func()
|
||||||
othertnet *netstack.Net
|
othertnet *netstack.Net
|
||||||
// Proxy manager for tunnel
|
// Proxy manager for tunnel
|
||||||
proxyManager *proxy.ProxyManager
|
TunnelIP string
|
||||||
TunnelIP string
|
|
||||||
// Shared bind and holepunch manager
|
// Shared bind and holepunch manager
|
||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
holePunchManager *holepunch.Manager
|
holePunchManager *holepunch.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProxyManager returns the proxy manager for this WireGuardService
|
|
||||||
func (s *WireGuardService) GetProxyManager() *proxy.ProxyManager {
|
|
||||||
return s.proxyManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddProxyTarget adds a target to the proxy manager
|
|
||||||
func (s *WireGuardService) AddProxyTarget(proto, listenIP string, port int, targetAddr string) error {
|
|
||||||
if s.proxyManager == nil {
|
|
||||||
return fmt.Errorf("proxy manager not initialized")
|
|
||||||
}
|
|
||||||
return s.proxyManager.AddTarget(proto, listenIP, port, targetAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveProxyTarget removes a target from the proxy manager
|
|
||||||
func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) error {
|
|
||||||
if s.proxyManager == nil {
|
|
||||||
return fmt.Errorf("proxy manager not initialized")
|
|
||||||
}
|
|
||||||
return s.proxyManager.RemoveTarget(proto, listenIP, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) {
|
||||||
var key wgtypes.Key
|
var key wgtypes.Key
|
||||||
var err error
|
var err error
|
||||||
@@ -189,7 +166,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
|||||||
lastReadings: make(map[string]PeerReading),
|
lastReadings: make(map[string]PeerReading),
|
||||||
Port: port,
|
Port: port,
|
||||||
dns: dnsAddrs,
|
dns: dnsAddrs,
|
||||||
proxyManager: proxy.NewProxyManagerWithoutTNet(),
|
|
||||||
sharedBind: sharedBind,
|
sharedBind: sharedBind,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,10 +178,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
|||||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||||
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
|
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
|
||||||
wsClient.RegisterHandler("newt/wg/tcp/add", service.addTcpTarget)
|
|
||||||
wsClient.RegisterHandler("newt/wg/udp/add", service.addUdpTarget)
|
|
||||||
wsClient.RegisterHandler("newt/wg/udp/remove", service.removeUdpTarget)
|
|
||||||
wsClient.RegisterHandler("newt/wg/tcp/remove", service.removeTcpTarget)
|
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -218,86 +190,6 @@ func (s *WireGuardService) ReportRTT(seconds float64) {
|
|||||||
telemetry.ObserveTunnelLatency(context.Background(), s.serverPubKey, "wireguard", seconds)
|
telemetry.ObserveTunnelLatency(context.Background(), s.serverPubKey, "wireguard", seconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
|
|
||||||
logger.Debug("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if s.TunnelIP == "" || s.proxyManager == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if s.TunnelIP == "" || s.proxyManager == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if s.TunnelIP == "" || s.proxyManager == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if s.TunnelIP == "" || s.proxyManager == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) {
|
func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) {
|
||||||
s.othertnet = tnet
|
s.othertnet = tnet
|
||||||
}
|
}
|
||||||
@@ -435,18 +327,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// add the targets if there are any
|
|
||||||
if len(config.Targets.TCP) > 0 {
|
|
||||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(config.Targets.UDP) > 0 {
|
|
||||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create ProxyManager for this tunnel
|
|
||||||
s.proxyManager.Start()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||||
@@ -484,7 +364,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return fmt.Errorf("failed to create TUN device: %v", err)
|
return fmt.Errorf("failed to create TUN device: %v", err)
|
||||||
}
|
}
|
||||||
// s.proxyManager.SetTNet(s.tnet)
|
|
||||||
s.TunnelIP = tunnelIP.String()
|
s.TunnelIP = tunnelIP.String()
|
||||||
|
|
||||||
// Create WireGuard device using the shared bind
|
// Create WireGuard device using the shared bind
|
||||||
@@ -921,169 +801,6 @@ func (s *WireGuardService) reportPeerBandwidth() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
|
||||||
var replace = false
|
|
||||||
for _, t := range targetData.Targets {
|
|
||||||
// Split the first number off of the target with : separator and use as the port
|
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
logger.Info("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
|
||||||
port := 0
|
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid port: %s", parts[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == "add" {
|
|
||||||
target := parts[1] + ":" + parts[2]
|
|
||||||
|
|
||||||
// Call updown script if provided
|
|
||||||
processedTarget := target
|
|
||||||
|
|
||||||
// Only remove the specific target if it exists
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
|
||||||
// Ignore "target not found" errors as this is expected for new targets
|
|
||||||
if !strings.Contains(err.Error(), "target not found") {
|
|
||||||
logger.Error("Failed to remove existing target: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
replace = true // We successfully removed an existing target
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the new target
|
|
||||||
pm.AddTarget(proto, tunnelIP, port, processedTarget)
|
|
||||||
|
|
||||||
} else if action == "remove" {
|
|
||||||
logger.Info("Removing target with port %d", port)
|
|
||||||
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to remove target: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if replace {
|
|
||||||
// If we replaced any targets, we need to hot swap the netstack
|
|
||||||
if err := s.ReplaceNetstack(); err != nil {
|
|
||||||
logger.Error("Failed to replace netstack after updating targets: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Info("Netstack replaced successfully after updating targets")
|
|
||||||
} else {
|
|
||||||
logger.Info("No targets updated, no netstack replacement needed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
|
||||||
var targetData TargetData
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
return targetData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add this method to WireGuardService
|
|
||||||
func (s *WireGuardService) ReplaceNetstack() error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
if s.device == nil || s.tun == nil {
|
|
||||||
return fmt.Errorf("WireGuard device not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the current tunnel IP from the existing config
|
|
||||||
parts := strings.Split(s.config.IpAddress, "/")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress)
|
|
||||||
}
|
|
||||||
tunnelIP := netip.MustParseAddr(parts[0])
|
|
||||||
|
|
||||||
// Stop the proxy manager temporarily
|
|
||||||
s.proxyManager.Stop()
|
|
||||||
|
|
||||||
// Create new TUN device and netstack with new DNS
|
|
||||||
newTun, newTnet, err := netstack2.CreateNetTUN(
|
|
||||||
[]netip.Addr{tunnelIP},
|
|
||||||
s.dns,
|
|
||||||
s.mtu)
|
|
||||||
if err != nil {
|
|
||||||
// Restart proxy manager with old tnet on failure
|
|
||||||
s.proxyManager.Start()
|
|
||||||
return fmt.Errorf("failed to create new TUN device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get current device config before closing
|
|
||||||
currentConfig, err := s.device.IpcGet()
|
|
||||||
if err != nil {
|
|
||||||
newTun.Close()
|
|
||||||
s.proxyManager.Start()
|
|
||||||
return fmt.Errorf("failed to get current device config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out read-only fields from the config
|
|
||||||
filteredConfig := s.filterReadOnlyFields(currentConfig)
|
|
||||||
|
|
||||||
// if onNetstackClose callback is set, call it
|
|
||||||
if s.onNetstackClose != nil {
|
|
||||||
s.onNetstackClose()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close old device (this closes the old TUN device)
|
|
||||||
s.device.Close()
|
|
||||||
|
|
||||||
// Update references
|
|
||||||
s.tun = newTun
|
|
||||||
s.tnet = newTnet
|
|
||||||
|
|
||||||
// Create new WireGuard device with same shared bind
|
|
||||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
|
||||||
device.LogLevelSilent,
|
|
||||||
"wireguard: ",
|
|
||||||
))
|
|
||||||
|
|
||||||
// Restore the configuration (without read-only fields)
|
|
||||||
err = s.device.IpcSet(filteredConfig)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to restore WireGuard configuration: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bring up the device
|
|
||||||
err = s.device.Up()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to bring up new WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update proxy manager with new tnet and restart
|
|
||||||
// s.proxyManager.SetTNet(s.tnet)
|
|
||||||
s.proxyManager.Start()
|
|
||||||
|
|
||||||
s.proxyManager.PrintTargets()
|
|
||||||
|
|
||||||
// Call the netstack ready callback if set
|
|
||||||
if s.onNetstackReady != nil {
|
|
||||||
go s.onNetstackReady(s.tnet)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
|
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
|
||||||
func (s *WireGuardService) filterReadOnlyFields(config string) string {
|
func (s *WireGuardService) filterReadOnlyFields(config string) string {
|
||||||
lines := strings.Split(config, "\n")
|
lines := strings.Split(config, "\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user