diff --git a/clients/clients.go b/clients/clients.go index c7fa9fc..a945985 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -46,6 +46,7 @@ type Target struct { type PortRange struct { Min uint16 `json:"min"` Max uint16 `json:"max"` + Protocol string `json:"protocol"` // "tcp" or "udp" } type Peer struct { @@ -702,6 +703,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { portRanges = append(portRanges, netstack2.PortRange{ Min: pr.Min, Max: pr.Max, + Protocol: pr.Protocol, }) } diff --git a/netstack2/proxy.go b/netstack2/proxy.go index ced91f9..90ec15e 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -22,10 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -// PortRange represents an allowed range of ports (inclusive) +// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering +// Protocol can be "tcp", "udp", or "" (empty string means both protocols) type PortRange struct { - Min uint16 - Max uint16 + Min uint16 + Max uint16 + Protocol string // "tcp", "udp", or "" for both } // SubnetRule represents a subnet with optional port restrictions and source address @@ -98,14 +100,16 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { delete(sl.rules, key) } -// Match checks if a source IP, destination IP, and port match any subnet rule -// Returns the matched rule if BOTH: +// Match checks if a source IP, destination IP, port, and protocol match any subnet rule +// Returns the matched rule if ALL of these conditions are met: // - The source IP is in the rule's source prefix // - The destination IP is in the rule's destination prefix // - The port is in an allowed range (or no port restrictions exist) +// - The protocol matches (or the port range allows both protocols) // +// proto should be header.TCPProtocolNumber or header.UDPProtocolNumber // Returns nil if no rule matches -func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule { +func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule { sl.mu.RLock() defer sl.mu.RUnlock() @@ -126,10 +130,20 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule return rule } - // Check if port is in any of the allowed ranges + // Check if port and protocol are in any of the allowed ranges for _, pr := range rule.PortRanges { if port >= pr.Min && port <= pr.Max { - return rule + // Check protocol compatibility + if pr.Protocol == "" { + // Empty protocol means allow both TCP and UDP + return rule + } + // Check if the packet protocol matches the port range protocol + if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) || + (pr.Protocol == "udp" && proto == header.UDPProtocolNumber) { + return rule + } + // Port matches but protocol doesn't - continue checking other ranges } } } @@ -435,8 +449,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { logger.Debug("HandleIncomingPacket: Unknown protocol %d from %s to %s", protocol, srcAddr, dstAddr) } - // Check if the source IP, destination IP, and port match any subnet rule - matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort) + // Check if the source IP, destination IP, port, and protocol match any subnet rule + matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) if matchedRule != nil { logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)", srcAddr, dstAddr, protocol, dstPort) diff --git a/udp_client.py b/udp_client.py new file mode 100644 index 0000000..2909d13 --- /dev/null +++ b/udp_client.py @@ -0,0 +1,49 @@ +import socket +import sys + +# Argument parsing: Check if IP and Port are provided +if len(sys.argv) != 3: + print("Usage: python udp_client.py ") + # Example: python udp_client.py 127.0.0.1 12000 + sys.exit(1) + +HOST = sys.argv[1] +try: + PORT = int(sys.argv[2]) +except ValueError: + print("Error: HOST_PORT must be an integer.") + sys.exit(1) + +# The message to send to the server +MESSAGE = "Hello UDP Server! How are you?" + +# Create a UDP socket +try: + client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +except socket.error as err: + print(f"Failed to create socket: {err}") + sys.exit() + +try: + print(f"Sending message to {HOST}:{PORT}...") + + # Send the message (data must be encoded to bytes) + client_socket.sendto(MESSAGE.encode('utf-8'), (HOST, PORT)) + + # Wait for the server's response (buffer size 1024 bytes) + data, server_address = client_socket.recvfrom(1024) + + # Decode and print the server's response + response = data.decode('utf-8') + print("-" * 30) + print(f"Received response from server {server_address[0]}:{server_address[1]}:") + print(f"-> Data: '{response}'") + +except socket.error as err: + print(f"Error during communication: {err}") + +finally: + # Close the socket + client_socket.close() + print("-" * 30) + print("Client finished and socket closed.") diff --git a/udp_server.py b/udp_server.py new file mode 100644 index 0000000..89aea28 --- /dev/null +++ b/udp_server.py @@ -0,0 +1,58 @@ +import socket +import sys + +# optionally take in some positional args for the port +if len(sys.argv) > 1: + try: + PORT = int(sys.argv[1]) + except ValueError: + print("Invalid port number. Using default port 12000.") + PORT = 12000 +else: + PORT = 12000 + +# Define the server host and port +HOST = '0.0.0.0' # Standard loopback interface address (localhost) + +# Create a UDP socket +try: + server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +except socket.error as err: + print(f"Failed to create socket: {err}") + sys.exit() + +# Bind the socket to the address +try: + server_socket.bind((HOST, PORT)) + print(f"UDP Server listening on {HOST}:{PORT}") +except socket.error as err: + print(f"Bind failed: {err}") + server_socket.close() + sys.exit() + +# Wait for and process incoming data +while True: + try: + # Receive data and the client's address (buffer size 1024 bytes) + data, client_address = server_socket.recvfrom(1024) + + # Decode the data and print the message + message = data.decode('utf-8') + print("-" * 30) + print(f"Received message from {client_address[0]}:{client_address[1]}:") + print(f"-> Data: '{message}'") + + # Prepare the response message + response_message = f"Hello client! Server received: '{message.upper()}'" + + # Send the response back to the client + server_socket.sendto(response_message.encode('utf-8'), client_address) + print(f"Sent response back to client.") + + except Exception as e: + print(f"An error occurred: {e}") + break + +# Clean up (though usually unreachable in an infinite server loop) +server_socket.close() +print("Server stopped.")