diff --git a/clients/clients.go b/clients/clients.go index 9b17d07..7bc2669 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -46,6 +46,7 @@ type Target struct { type PortRange struct { Min uint16 `json:"min"` Max uint16 `json:"max"` + Protocol string `json:"protocol"` // "tcp" or "udp" } type Peer struct { @@ -701,6 +702,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { portRanges = append(portRanges, netstack2.PortRange{ Min: pr.Min, Max: pr.Max, + Protocol: pr.Protocol, }) } diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 77a9d23..3338cd0 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -22,10 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -// PortRange represents an allowed range of ports (inclusive) +// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering +// Protocol can be "tcp", "udp", or "" (empty string means both protocols) type PortRange struct { - Min uint16 - Max uint16 + Min uint16 + Max uint16 + Protocol string // "tcp", "udp", or "" for both } // SubnetRule represents a subnet with optional port restrictions and source address @@ -97,14 +99,16 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { delete(sl.rules, key) } -// Match checks if a source IP, destination IP, and port match any subnet rule -// Returns the matched rule if BOTH: +// Match checks if a source IP, destination IP, port, and protocol match any subnet rule +// Returns the matched rule if ALL of these conditions are met: // - The source IP is in the rule's source prefix // - The destination IP is in the rule's destination prefix // - The port is in an allowed range (or no port restrictions exist) +// - The protocol matches (or the port range allows both protocols) // +// proto should be header.TCPProtocolNumber or header.UDPProtocolNumber // Returns nil if no rule matches -func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule { +func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule { sl.mu.RLock() defer sl.mu.RUnlock() @@ -125,10 +129,20 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule return rule } - // Check if port is in any of the allowed ranges + // Check if port and protocol are in any of the allowed ranges for _, pr := range rule.PortRanges { if port >= pr.Min && port <= pr.Max { - return rule + // Check protocol compatibility + if pr.Protocol == "" { + // Empty protocol means allow both TCP and UDP + return rule + } + // Check if the packet protocol matches the port range protocol + if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) || + (pr.Protocol == "udp" && proto == header.UDPProtocolNumber) { + return rule + } + // Port matches but protocol doesn't - continue checking other ranges } } } @@ -412,8 +426,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { dstPort = 0 } - // Check if the source IP, destination IP, and port match any subnet rule - matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort) + // Check if the source IP, destination IP, port, and protocol match any subnet rule + matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) if matchedRule != nil { // Check if we need to perform DNAT if matchedRule.RewriteTo != "" {