diff --git a/netstack2/proxy.go b/netstack2/proxy.go new file mode 100644 index 0000000..2a1fa03 --- /dev/null +++ b/netstack2/proxy.go @@ -0,0 +1,321 @@ +package netstack2 + +import ( + "fmt" + "net/netip" + "sync" + + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +// PortRange represents an allowed range of ports (inclusive) +type PortRange struct { + Min uint16 + Max uint16 +} + +// SubnetRule represents a subnet with optional port restrictions +type SubnetRule struct { + Prefix netip.Prefix + PortRanges []PortRange // empty slice means all ports allowed +} + +// SubnetLookup provides fast IP subnet and port matching +type SubnetLookup struct { + mu sync.RWMutex + rules []SubnetRule +} + +// NewSubnetLookup creates a new subnet lookup table +func NewSubnetLookup() *SubnetLookup { + return &SubnetLookup{ + rules: make([]SubnetRule, 0), + } +} + +// AddSubnet adds a subnet to the lookup table with optional port restrictions +// If portRanges is nil or empty, all ports are allowed for this subnet +func (sl *SubnetLookup) AddSubnet(prefix netip.Prefix, portRanges []PortRange) { + sl.mu.Lock() + defer sl.mu.Unlock() + + sl.rules = append(sl.rules, SubnetRule{ + Prefix: prefix, + PortRanges: portRanges, + }) +} + +// RemoveSubnet removes a subnet from the lookup table +func (sl *SubnetLookup) RemoveSubnet(prefix netip.Prefix) { + sl.mu.Lock() + defer sl.mu.Unlock() + + for i, rule := range sl.rules { + if rule.Prefix == prefix { + sl.rules = append(sl.rules[:i], sl.rules[i+1:]...) + return + } + } +} + +// Match checks if an IP and port match any subnet rule +// Returns true if the IP is in a matching subnet AND the port is in an allowed range +func (sl *SubnetLookup) Match(ip netip.Addr, port uint16) bool { + sl.mu.RLock() + defer sl.mu.RUnlock() + + for _, rule := range sl.rules { + if rule.Prefix.Contains(ip) { + // If no port ranges specified, all ports are allowed + if len(rule.PortRanges) == 0 { + return true + } + + // Check if port is in any of the allowed ranges + for _, pr := range rule.PortRanges { + if port >= pr.Min && port <= pr.Max { + return true + } + } + } + } + + return false +} + +// ProxyHandler handles packet injection and extraction for promiscuous mode +type ProxyHandler struct { + proxyStack *stack.Stack + proxyEp *channel.Endpoint + proxyNotifyHandle *channel.NotificationHandle + tcpHandler *TCPHandler + udpHandler *UDPHandler + subnetLookup *SubnetLookup + enabled bool +} + +// ProxyHandlerOptions configures the proxy handler +type ProxyHandlerOptions struct { + EnableTCP bool + EnableUDP bool + MTU int +} + +// NewProxyHandler creates a new proxy handler for promiscuous mode +func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { + if !options.EnableTCP && !options.EnableUDP { + return nil, nil // No proxy needed + } + + handler := &ProxyHandler{ + enabled: true, + subnetLookup: NewSubnetLookup(), + proxyEp: channel.New(1024, uint32(options.MTU), ""), + proxyStack: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + icmp.NewProtocol4, + icmp.NewProtocol6, + }, + }), + } + + // Initialize TCP handler if enabled + if options.EnableTCP { + handler.tcpHandler = NewTCPHandler(handler.proxyStack) + if err := handler.tcpHandler.InstallTCPHandler(); err != nil { + return nil, fmt.Errorf("failed to install TCP handler: %v", err) + } + } + + // Initialize UDP handler if enabled + if options.EnableUDP { + handler.udpHandler = NewUDPHandler(handler.proxyStack) + if err := handler.udpHandler.InstallUDPHandler(); err != nil { + return nil, fmt.Errorf("failed to install UDP handler: %v", err) + } + } + + // Example 1: Add a subnet with no port restrictions (all ports allowed) + // This accepts all traffic to 10.20.20.0/24 + subnet1 := netip.MustParsePrefix("10.20.20.0/24") + handler.AddSubnetRule(subnet1, nil) + + // Example 2: Add a subnet with specific port ranges + // This accepts traffic to 192.168.1.0/24 only on ports 80, 443, and 8000-9000 + subnet2 := netip.MustParsePrefix("10.20.21.21/32") + handler.AddSubnetRule(subnet2, []PortRange{ + {Min: 12000, Max: 12001}, + {Min: 8000, Max: 8000}, + }) + + return handler, nil +} + +// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler +// If portRanges is nil or empty, all ports are allowed for this subnet +func (p *ProxyHandler) AddSubnetRule(prefix netip.Prefix, portRanges []PortRange) { + if p == nil || !p.enabled { + return + } + p.subnetLookup.AddSubnet(prefix, portRanges) +} + +// RemoveSubnetRule removes a subnet from the proxy handler +func (p *ProxyHandler) RemoveSubnetRule(prefix netip.Prefix) { + if p == nil || !p.enabled { + return + } + p.subnetLookup.RemoveSubnet(prefix) +} + +// Initialize sets up the promiscuous NIC with the netTun's notification system +func (p *ProxyHandler) Initialize(notifiable channel.Notification) error { + if p == nil || !p.enabled { + return nil + } + + // Add notification handler + p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable) + + // Create NIC with promiscuous mode + tcpipErr := p.proxyStack.CreateNICWithOptions(1, p.proxyEp, stack.NICOptions{ + Disabled: false, + QDisc: nil, + }) + if tcpipErr != nil { + return fmt.Errorf("CreateNIC (proxy): %v", tcpipErr) + } + + // Enable promiscuous mode - accepts packets for any destination IP + if tcpipErr := p.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil { + return fmt.Errorf("SetPromiscuousMode: %v", tcpipErr) + } + + // Enable spoofing - allows sending packets from any source IP + if tcpipErr := p.proxyStack.SetSpoofing(1, true); tcpipErr != nil { + return fmt.Errorf("SetSpoofing: %v", tcpipErr) + } + + // Add default route + p.proxyStack.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }) + + return nil +} + +// HandleIncomingPacket processes incoming packets and determines if they should +// be injected into the proxy stack +func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { + if p == nil || !p.enabled { + return false + } + + // Check minimum packet size + if len(packet) < header.IPv4MinimumSize { + return false + } + + // Only handle IPv4 for now + if packet[0]>>4 != 4 { + return false + } + + // Parse IPv4 header + ipv4Header := header.IPv4(packet) + dstIP := ipv4Header.DestinationAddress() + + // Convert gvisor tcpip.Address to netip.Addr + dstBytes := dstIP.As4() + addr := netip.AddrFrom4(dstBytes) + + // Parse transport layer to get destination port + var dstPort uint16 + protocol := ipv4Header.TransportProtocol() + headerLen := int(ipv4Header.HeaderLength()) + + // Extract port based on protocol + switch protocol { + case header.TCPProtocolNumber: + if len(packet) < headerLen+header.TCPMinimumSize { + return false + } + tcpHeader := header.TCP(packet[headerLen:]) + dstPort = tcpHeader.DestinationPort() + case header.UDPProtocolNumber: + if len(packet) < headerLen+header.UDPMinimumSize { + return false + } + udpHeader := header.UDP(packet[headerLen:]) + dstPort = udpHeader.DestinationPort() + default: + // For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions) + dstPort = 0 + } + + // Check if the destination IP and port match any subnet rule + if p.subnetLookup.Match(addr, dstPort) { + // Inject into proxy stack + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + return true + } + + return false +} + +// ReadOutgoingPacket reads packets from the proxy stack that need to be +// sent back through the tunnel +func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { + if p == nil || !p.enabled { + return nil + } + + pkt := p.proxyEp.Read() + if pkt != nil { + view := pkt.ToView() + pkt.DecRef() + return view + } + + return nil +} + +// Close cleans up the proxy handler resources +func (p *ProxyHandler) Close() error { + if p == nil || !p.enabled { + return nil + } + + if p.proxyStack != nil { + p.proxyStack.RemoveNIC(1) + p.proxyStack.Close() + } + + if p.proxyEp != nil { + if p.proxyNotifyHandle != nil { + p.proxyEp.RemoveNotify(p.proxyNotifyHandle) + } + p.proxyEp.Close() + } + + return nil +} diff --git a/netstack2/tun.go b/netstack2/tun.go index ca2511c..80dac39 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -22,7 +22,6 @@ import ( "syscall" "time" - "github.com/fosrl/newt/logger" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" @@ -41,19 +40,15 @@ import ( ) type netTun struct { - ep *channel.Endpoint - proxyEp *channel.Endpoint // Separate endpoint for promiscuous mode - stack *stack.Stack - proxyStack *stack.Stack // Separate stack for proxy endpoint - events chan tun.Event - notifyHandle *channel.NotificationHandle - proxyNotifyHandle *channel.NotificationHandle // Notify handle for proxy endpoint - incomingPacket chan *buffer.View - mtu int - dnsServers []netip.Addr - hasV4, hasV6 bool - tcpHandler *TCPHandler - udpHandler *UDPHandler + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + notifyHandle *channel.NotificationHandle + incomingPacket chan *buffer.View + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool + proxyHandler *ProxyHandler // Handles promiscuous mode packet processing } type Net netTun @@ -80,27 +75,27 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o HandleLocal: true, } dev := &netTun{ - ep: channel.New(1024, uint32(mtu), ""), - proxyEp: channel.New(1024, uint32(mtu), ""), - stack: stack.New(stackOpts), - proxyStack: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{ - tcp.NewProtocol, - udp.NewProtocol, - icmp.NewProtocol4, - icmp.NewProtocol6, - }, - }), + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(stackOpts), events: make(chan tun.Event, 10), incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } + // Initialize proxy handler if TCP or UDP proxying is enabled + if options.EnableTCPProxy || options.EnableUDPProxy { + var err error + dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{ + EnableTCP: options.EnableTCPProxy, + EnableUDP: options.EnableUDPProxy, + MTU: mtu, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err) + } + } + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { @@ -113,6 +108,13 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } + // Initialize proxy handler after main stack is set up + if dev.proxyHandler != nil { + if err := dev.proxyHandler.Initialize(dev); err != nil { + return nil, nil, err + } + } + if err := dev.stack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { return nil, nil, fmt.Errorf("set ipv4 forwarding: %s", err) } @@ -145,111 +147,6 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) } - // Add specific route for proxy network (10.20.20.0/24) to NIC 2 - if options.EnableTCPProxy || options.EnableUDPProxy { - - if options.EnableTCPProxy { - dev.tcpHandler = NewTCPHandler(dev.proxyStack) - if err := dev.tcpHandler.InstallTCPHandler(); err != nil { - return nil, nil, fmt.Errorf("failed to install TCP handler: %v", err) - } - } - - if options.EnableUDPProxy { - dev.udpHandler = NewUDPHandler(dev.proxyStack) - if err := dev.udpHandler.InstallUDPHandler(); err != nil { - return nil, nil, fmt.Errorf("failed to install UDP handler: %v", err) - } - } - - dev.proxyNotifyHandle = dev.proxyEp.AddNotify(dev) - tcpipErr = dev.proxyStack.CreateNICWithOptions(1, dev.proxyEp, stack.NICOptions{ - Disabled: false, - // If no queueing discipline was specified - // provide a stub implementation that just - // delegates to the lower link endpoint. - QDisc: nil, - }) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("CreateNIC 2 (proxy): %v", tcpipErr) - } - - // Enable promiscuous mode ONLY on NIC 2 - // This allows the NIC to accept packets destined for any IP address - if tcpipErr := dev.proxyStack.SetPromiscuousMode(1, true); tcpipErr != nil { - return nil, nil, fmt.Errorf("SetPromiscuousMode on NIC 2: %v", tcpipErr) - } - - // Enable spoofing ONLY on NIC 2 - // This allows the stack to send packets from any address, not just owned addresses - if tcpipErr := dev.proxyStack.SetSpoofing(1, true); tcpipErr != nil { - return nil, nil, fmt.Errorf("SetSpoofing on NIC 2: %v", tcpipErr) - } - - // // Add a wildcard IPv4 address covering the 10.0.0.0/8 space so the stack can - // // synthesize temporary endpoints for any 10.x.y.z destination. This mimics - // // the tun2socks behaviour and is required once promiscuous+spoofing are turned on. - // wildcardAddr := tcpip.ProtocolAddress{ - // Protocol: ipv4.ProtocolNumber, - // AddressWithPrefix: tcpip.AddressWithPrefix{ - // Address: tcpip.AddrFrom4([4]byte{10, 0, 0, 1}), - // PrefixLen: 8, - // }, - // } - // if tcpipErr = dev.stack.AddProtocolAddress(2, wildcardAddr, stack.AddressProperties{ - // PEB: stack.CanBePrimaryEndpoint, - // }); tcpipErr != nil { - // return nil, nil, fmt.Errorf("Add wildcard proxy address: %v", tcpipErr) - // } - - // // Keep the real service IP (10.20.20.1/24) so existing clients that target the - // // gateway explicitly still resolve as before. - // proxyAddr := netip.MustParseAddr("10.20.20.1") - // protoAddr := tcpip.ProtocolAddress{ - // Protocol: ipv4.ProtocolNumber, - // AddressWithPrefix: tcpip.AddressWithPrefix{ - // Address: tcpip.AddrFromSlice(proxyAddr.AsSlice()), - // PrefixLen: 24, - // }, - // } - // if tcpipErr = dev.stack.AddProtocolAddress(2, protoAddr, stack.AddressProperties{}); tcpipErr != nil { - // return nil, nil, fmt.Errorf("Add proxy service address: %v", tcpipErr) - // } - - // proxySubnet := netip.MustParsePrefix("10.20.20.0/24") - // proxyTcpipSubnet, err := tcpip.NewSubnet( - // tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()), - // tcpip.MaskFromBytes(net.CIDRMask(24, 32)), - // ) - // if err != nil { - // return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err) - // } - - dev.proxyStack.AddRoute(tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }) - } - - // print the stack routes table and interfaces for debugging - logger.Info("Stack configuration:") - - // // Print NICs - // nics := dev.stack.NICInfo() - // for nicID, nicInfo := range nics { - // logger.Info("NIC %d: %s (MTU: %d)", nicID, nicInfo.Name, nicInfo.MTU) - // for _, addr := range nicInfo.ProtocolAddresses { - // logger.Info(" Address: %s", addr.AddressWithPrefix) - // } - // } - - // // Print routing table - // routes := dev.stack.GetRouteTable() - // logger.Info("Routing table (%d routes):", len(routes)) - // for i, route := range routes { - // logger.Info(" Route %d: %s -> NIC %d", i, route.Destination, route.NIC) - // } - dev.events <- tun.EventUp return dev, (*Net)(dev), nil } @@ -287,32 +184,20 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { continue } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + // Try to handle packet via proxy handler first + if tun.proxyHandler != nil && tun.proxyHandler.HandleIncomingPacket(packet) { + // Packet was handled by proxy + continue + } - // Determine which NIC to inject the packet into based on destination IP - targetEp := tun.ep // Default to NIC 1 + // Default handling: inject into main stack + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: - // // Parse IPv4 header to check destination - if len(packet) >= header.IPv4MinimumSize { - ipv4Header := header.IPv4(packet) - dstIP := ipv4Header.DestinationAddress() - - // Check if destination is in the proxy range (10.20.20.0/24) - // If so, inject into proxyEp (NIC 2) which has promiscuous mode - if tun.proxyEp != nil { - dstBytes := dstIP.As4() - // Check for 10.20.20.x - if dstBytes[0] == 10 && dstBytes[1] == 20 && dstBytes[2] == 20 { - targetEp = tun.proxyEp - } - } - } - targetEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) case 6: - // For IPv6, always use NIC 1 for now - targetEp.InjectInbound(header.IPv6ProtocolNumber, pkb) + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) default: return 0, syscall.EAFNOSUPPORT } @@ -320,86 +205,6 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { return len(buf), nil } -// logPacketDetails parses and logs packet information -func logPacketDetails(pkt *stack.PacketBuffer, nicID int) { - netProto := pkt.NetworkProtocolNumber - var srcIP, dstIP string - var protocol string - var srcPort, dstPort uint16 - - // Parse network layer - switch netProto { - case header.IPv4ProtocolNumber: - if pkt.NetworkHeader().View().Size() >= header.IPv4MinimumSize { - ipv4 := header.IPv4(pkt.NetworkHeader().Slice()) - srcIP = ipv4.SourceAddress().String() - dstIP = ipv4.DestinationAddress().String() - - // Parse transport layer - switch ipv4.Protocol() { - case uint8(header.TCPProtocolNumber): - protocol = "TCP" - if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize { - tcp := header.TCP(pkt.TransportHeader().Slice()) - srcPort = tcp.SourcePort() - dstPort = tcp.DestinationPort() - } - case uint8(header.UDPProtocolNumber): - protocol = "UDP" - if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize { - udp := header.UDP(pkt.TransportHeader().Slice()) - srcPort = udp.SourcePort() - dstPort = udp.DestinationPort() - } - case uint8(header.ICMPv4ProtocolNumber): - protocol = "ICMPv4" - default: - protocol = fmt.Sprintf("Proto-%d", ipv4.Protocol()) - } - } - case header.IPv6ProtocolNumber: - if pkt.NetworkHeader().View().Size() >= header.IPv6MinimumSize { - ipv6 := header.IPv6(pkt.NetworkHeader().Slice()) - srcIP = ipv6.SourceAddress().String() - dstIP = ipv6.DestinationAddress().String() - - // Parse transport layer - switch ipv6.TransportProtocol() { - case header.TCPProtocolNumber: - protocol = "TCP" - if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize { - tcp := header.TCP(pkt.TransportHeader().Slice()) - srcPort = tcp.SourcePort() - dstPort = tcp.DestinationPort() - } - case header.UDPProtocolNumber: - protocol = "UDP" - if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize { - udp := header.UDP(pkt.TransportHeader().Slice()) - srcPort = udp.SourcePort() - dstPort = udp.DestinationPort() - } - case header.ICMPv6ProtocolNumber: - protocol = "ICMPv6" - default: - protocol = fmt.Sprintf("Proto-%d", ipv6.TransportProtocol()) - } - } - default: - protocol = fmt.Sprintf("Unknown-NetProto-%d", netProto) - } - - packetSize := pkt.Size() - - if srcPort > 0 && dstPort > 0 { - logger.Info("NIC %d packet: %s %s:%d -> %s:%d (size: %d bytes)", - nicID, protocol, srcIP, srcPort, dstIP, dstPort, packetSize) - } else { - logger.Info("NIC %d packet: %s %s -> %s (size: %d bytes)", - nicID, protocol, srcIP, dstIP, packetSize) - } -} - func (tun *netTun) WriteNotify() { // Handle notifications from main endpoint (NIC 1) pkt := tun.ep.Read() @@ -410,13 +215,11 @@ func (tun *netTun) WriteNotify() { return } - // Handle notifications from proxy endpoint (NIC 2) if it exists + // Handle notifications from proxy handler if it exists // These are response packets from the proxied connections that need to go back to WireGuard - if tun.proxyEp != nil { - pkt = tun.proxyEp.Read() - if pkt != nil { - view := pkt.ToView() - pkt.DecRef() + if tun.proxyHandler != nil { + view := tun.proxyHandler.ReadOutgoingPacket() + if view != nil { tun.incomingPacket <- view return } @@ -426,17 +229,15 @@ func (tun *netTun) WriteNotify() { func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) - // Clean up proxy NIC if it exists - if tun.proxyEp != nil { - tun.stack.RemoveNIC(2) - tun.proxyEp.RemoveNotify(tun.proxyNotifyHandle) - tun.proxyEp.Close() - } - tun.stack.Close() tun.ep.RemoveNotify(tun.notifyHandle) tun.ep.Close() + // Clean up proxy handler if it exists + if tun.proxyHandler != nil { + tun.proxyHandler.Close() + } + if tun.events != nil { close(tun.events) } diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 13edbfd..d1604db 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -18,7 +18,6 @@ import ( "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/netstack2" - "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/util" "github.com/fosrl/newt/websocket" "golang.zx2c4.com/wireguard/device" @@ -88,34 +87,12 @@ type WireGuardService struct { onNetstackClose func() othertnet *netstack.Net // Proxy manager for tunnel - proxyManager *proxy.ProxyManager - TunnelIP string + TunnelIP string // Shared bind and holepunch manager sharedBind *bind.SharedBind holePunchManager *holepunch.Manager } -// GetProxyManager returns the proxy manager for this WireGuardService -func (s *WireGuardService) GetProxyManager() *proxy.ProxyManager { - return s.proxyManager -} - -// AddProxyTarget adds a target to the proxy manager -func (s *WireGuardService) AddProxyTarget(proto, listenIP string, port int, targetAddr string) error { - if s.proxyManager == nil { - return fmt.Errorf("proxy manager not initialized") - } - return s.proxyManager.AddTarget(proto, listenIP, port, targetAddr) -} - -// RemoveProxyTarget removes a target from the proxy manager -func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) error { - if s.proxyManager == nil { - return fmt.Errorf("proxy manager not initialized") - } - return s.proxyManager.RemoveTarget(proto, listenIP, port) -} - func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) { var key wgtypes.Key var err error @@ -189,7 +166,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str lastReadings: make(map[string]PeerReading), Port: port, dns: dnsAddrs, - proxyManager: proxy.NewProxyManagerWithoutTNet(), sharedBind: sharedBind, } @@ -202,10 +178,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) - wsClient.RegisterHandler("newt/wg/tcp/add", service.addTcpTarget) - wsClient.RegisterHandler("newt/wg/udp/add", service.addUdpTarget) - wsClient.RegisterHandler("newt/wg/udp/remove", service.removeUdpTarget) - wsClient.RegisterHandler("newt/wg/tcp/remove", service.removeTcpTarget) return service, nil } @@ -218,86 +190,6 @@ func (s *WireGuardService) ReportRTT(seconds float64) { telemetry.ObserveTunnelLatency(context.Background(), s.serverPubKey, "wireguard", seconds) } -func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) { - logger.Debug("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData) - } -} - -func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData) - } -} - -func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData) - } -} - -func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if s.TunnelIP == "" || s.proxyManager == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData) - } -} - func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) { s.othertnet = tnet } @@ -435,18 +327,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { if err := s.ensureWireguardPeers(config.Peers); err != nil { logger.Error("Failed to ensure WireGuard peers: %v", err) } - - // add the targets if there are any - if len(config.Targets.TCP) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP}) - } - - if len(config.Targets.UDP) > 0 { - s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP}) - } - - // Create ProxyManager for this tunnel - s.proxyManager.Start() } func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { @@ -484,7 +364,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.mu.Unlock() return fmt.Errorf("failed to create TUN device: %v", err) } - // s.proxyManager.SetTNet(s.tnet) + s.TunnelIP = tunnelIP.String() // Create WireGuard device using the shared bind @@ -921,169 +801,6 @@ func (s *WireGuardService) reportPeerBandwidth() error { return nil } -func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - var replace = false - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - logger.Info("Invalid port: %s", parts[0]) - continue - } - - if action == "add" { - target := parts[1] + ":" + parts[2] - - // Call updown script if provided - processedTarget := target - - // Only remove the specific target if it exists - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - // Ignore "target not found" errors as this is expected for new targets - if !strings.Contains(err.Error(), "target not found") { - logger.Error("Failed to remove existing target: %v", err) - } - } else { - replace = true // We successfully removed an existing target - } - - // Add the new target - pm.AddTarget(proto, tunnelIP, port, processedTarget) - - } else if action == "remove" { - logger.Info("Removing target with port %d", port) - - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - logger.Error("Failed to remove target: %v", err) - return err - } - } - } - - if replace { - // If we replaced any targets, we need to hot swap the netstack - if err := s.ReplaceNetstack(); err != nil { - logger.Error("Failed to replace netstack after updating targets: %v", err) - return err - } - logger.Info("Netstack replaced successfully after updating targets") - } else { - logger.Info("No targets updated, no netstack replacement needed") - } - - return nil -} - -func parseTargetData(data interface{}) (TargetData, error) { - var targetData TargetData - jsonData, err := json.Marshal(data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return targetData, err - } - - if err := json.Unmarshal(jsonData, &targetData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return targetData, err - } - return targetData, nil -} - -// Add this method to WireGuardService -func (s *WireGuardService) ReplaceNetstack() error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.device == nil || s.tun == nil { - return fmt.Errorf("WireGuard device not initialized") - } - - // Parse the current tunnel IP from the existing config - parts := strings.Split(s.config.IpAddress, "/") - if len(parts) != 2 { - return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress) - } - tunnelIP := netip.MustParseAddr(parts[0]) - - // Stop the proxy manager temporarily - s.proxyManager.Stop() - - // Create new TUN device and netstack with new DNS - newTun, newTnet, err := netstack2.CreateNetTUN( - []netip.Addr{tunnelIP}, - s.dns, - s.mtu) - if err != nil { - // Restart proxy manager with old tnet on failure - s.proxyManager.Start() - return fmt.Errorf("failed to create new TUN device: %v", err) - } - - // Get current device config before closing - currentConfig, err := s.device.IpcGet() - if err != nil { - newTun.Close() - s.proxyManager.Start() - return fmt.Errorf("failed to get current device config: %v", err) - } - - // Filter out read-only fields from the config - filteredConfig := s.filterReadOnlyFields(currentConfig) - - // if onNetstackClose callback is set, call it - if s.onNetstackClose != nil { - s.onNetstackClose() - } - - // Close old device (this closes the old TUN device) - s.device.Close() - - // Update references - s.tun = newTun - s.tnet = newTnet - - // Create new WireGuard device with same shared bind - s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( - device.LogLevelSilent, - "wireguard: ", - )) - - // Restore the configuration (without read-only fields) - err = s.device.IpcSet(filteredConfig) - if err != nil { - return fmt.Errorf("failed to restore WireGuard configuration: %v", err) - } - - // Bring up the device - err = s.device.Up() - if err != nil { - return fmt.Errorf("failed to bring up new WireGuard device: %v", err) - } - - // Update proxy manager with new tnet and restart - // s.proxyManager.SetTNet(s.tnet) - s.proxyManager.Start() - - s.proxyManager.PrintTargets() - - // Call the netstack ready callback if set - if s.onNetstackReady != nil { - go s.onNetstackReady(s.tnet) - } - - return nil -} - // filterReadOnlyFields removes read-only fields from WireGuard IPC configuration func (s *WireGuardService) filterReadOnlyFields(config string) string { lines := strings.Split(config, "\n")