diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index d067a3e7b..84d8cb0d9 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering( } rule := genRouteFilteringRuleSpec(params) - if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + // Insert DROP rules at the beginning, append ACCEPT rules at the end + var err error + if action == firewall.ActionDrop { + // after the established rule + err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...) + } else { + err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...) + } + + if err != nil { return nil, fmt.Errorf("add route rule: %v", err) } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 9c9637282..b519d55ba 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) { Kind: expr.VerdictAccept, }, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") + compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() @@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } + +func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) { + t.Helper() + require.Equal(t, len(got), len(want), "expression count mismatch") + + for i := range got { + if _, isCounter := got[i].(*expr.Counter); isCounter { + _, wantIsCounter := want[i].(*expr.Counter) + require.True(t, wantIsCounter, "expected Counter at index %d", i) + continue + } + + require.Equal(t, got[i], want[i], "expression mismatch at index %d", i) + } +} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 34bc9a9bc..5a02e2895 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - rule = r.conn.AddRule(rule) + // Insert DROP rules at the beginning, append ACCEPT rules at the end + if action == firewall.ActionDrop { + // TODO: Insert after the established rule + rule = r.conn.InsertRule(rule) + } else { + rule = r.conn.AddRule(rule) + } log.Tracef("Adding route rule %s", spew.Sdump(rule)) if err := r.conn.Flush(); err != nil { diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 99a3dcee0..fa3068642 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -6,7 +6,9 @@ import ( "net" "net/netip" "os" + "slices" "strconv" + "strings" "sync" "github.com/google/gopacket" @@ -43,13 +45,28 @@ const ( // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule +type RouteRules []RouteRule + +func (r RouteRules) Sort() { + slices.SortStableFunc(r, func(a, b RouteRule) int { + // Deny rules come first + if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { + return -1 + } + if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop { + return 1 + } + return strings.Compare(a.id, b.id) + }) +} + // Manager userspace firewall manager type Manager struct { // outgoingRules is used for hooks only outgoingRules map[string]RuleSet // incomingRules is used for filtering and hooks incomingRules map[string]RuleSet - routeRules map[string]RouteRule + routeRules RouteRules wgNetwork *net.IPNet decoders sync.Pool wgIface common.IFaceMapper @@ -135,7 +152,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe nativeFirewall: nativeFirewall, outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), - routeRules: make(map[string]RouteRule), wgIface: iface, localipmanager: newLocalIPManager(), routingEnabled: false, @@ -377,7 +393,8 @@ func (m *Manager) AddRouteFiltering( action: action, } - m.routeRules[ruleID] = rule + m.routeRules = append(m.routeRules, rule) + m.routeRules.Sort() return &rule, nil } @@ -391,11 +408,14 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { defer m.mutex.Unlock() ruleID := rule.GetRuleID() - if _, exists := m.routeRules[ruleID]; !exists { + idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { + return r.id == ruleID + }) + if idx < 0 { return fmt.Errorf("route rule not found: %s", ruleID) } - delete(m.routeRules, ruleID) + m.routeRules = slices.Delete(m.routeRules, idx, idx+1) return nil } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 73209a152..ef1a0bed3 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -713,6 +713,56 @@ func TestRouteACLFiltering(t *testing.T) { }, shouldPass: true, }, + { + name: "Drop TCP traffic to specific destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []int{443}}, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + { + name: "Drop all traffic to specific destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + { + name: "Drop traffic from multiple source networks", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []int{80}}, + action: fw.ActionDrop, + }, + shouldPass: false, + }, } for _, tc := range testCases { @@ -742,3 +792,190 @@ func TestRouteACLFiltering(t *testing.T) { }) } } + + +func TestRouteACLOrder(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + type testCase struct { + name string + rules []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + } + packets []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + } + } + + testCases := []testCase{ + { + name: "Drop rules take precedence over accept", + rules: []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + }{ + { + // Accept rule added first + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []int{80, 443}}, + action: fw.ActionAccept, + }, + { + // Drop rule added second but should be evaluated first + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []int{443}}, + action: fw.ActionDrop, + }, + }, + packets: []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + }{ + { + // Should be dropped by the drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + shouldPass: false, + }, + { + // Should be allowed by the accept rule (port 80 not in drop rule) + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + shouldPass: true, + }, + }, + }, + { + name: "Multiple drop rules take precedence", + rules: []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + }{ + { + // Accept all + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolALL, + action: fw.ActionAccept, + }, + { + // Drop specific port + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []int{443}}, + action: fw.ActionDrop, + }, + { + // Drop different port + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []int{80}}, + action: fw.ActionDrop, + }, + }, + packets: []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + }{ + { + // Should be dropped by first drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + shouldPass: false, + }, + { + // Should be dropped by second drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + shouldPass: false, + }, + { + // Should be allowed by the accept rule (different port) + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + shouldPass: true, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var rules []fw.Rule + for _, r := range tc.rules { + rule, err := manager.AddRouteFiltering( + r.sources, + r.dest, + r.proto, + r.srcPort, + r.dstPort, + r.action, + ) + require.NoError(t, err) + require.NotNil(t, rule) + rules = append(rules, rule) + } + + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeleteRouteRule(rule)) + } + }) + + for i, p := range tc.packets { + srcIP := net.ParseIP(p.srcIP) + dstIP := net.ParseIP(p.dstIP) + + isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) + require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) + } + }) + } +}