Files
netbird/client/inspect/rules_test.go

339 lines
9.1 KiB
Go

package inspect
import (
"net/netip"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// testLogger returns a logrus entry built from the package-level logger and
// tagged with the field test=true, for handing to the rule engine in tests.
func testLogger() *log.Entry {
	return log.WithFields(log.Fields{"test": true})
}
// mustDomain parses s with domain.FromString and returns the result.
// If parsing fails, the test stops immediately (require), and the failure
// message names the input. This matters because mustDomain is called while
// test tables are built, so otherwise a bad table entry is hard to locate.
func mustDomain(t *testing.T, s string) domain.Domain {
	t.Helper()
	d, err := domain.FromString(s)
	require.NoError(t, err, "parsing domain %q", s)
	return d
}
// TestRuleEngine_Evaluate is a table-driven test of RuleEngine.Evaluate.
// It covers:
//   - the default action when no rule matches
//   - domain matching: exact, wildcard, and case-insensitive
//   - source CIDR, destination network, and port matching
//   - AND semantics across rule fields
//   - priority ordering of rules
//
// Protocol and path arguments are left empty here; they are covered by the
// protocol-specific tests.
func TestRuleEngine_Evaluate(t *testing.T) {
	tests := []struct {
		name          string
		rules         []Rule
		defaultAction Action
		src           netip.Addr
		dstDomain     domain.Domain
		dstAddr       netip.Addr
		dstPort       uint16
		want          Action
	}{
		{
			name:          "no rules returns default allow",
			defaultAction: ActionAllow,
			src:           netip.MustParseAddr("10.0.0.1"),
			dstAddr:       netip.MustParseAddr("1.2.3.4"),
			dstPort:       443,
			want:          ActionAllow,
		},
		{
			name:          "no rules returns default block",
			defaultAction: ActionBlock,
			src:           netip.MustParseAddr("10.0.0.1"),
			dstAddr:       netip.MustParseAddr("1.2.3.4"),
			dstPort:       443,
			want:          ActionBlock,
		},
		{
			name:          "domain exact match blocks",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
					Action:  ActionBlock,
				},
			},
			src:       netip.MustParseAddr("10.0.0.1"),
			dstDomain: mustDomain(t, "malware.example.com"),
			dstAddr:   netip.MustParseAddr("1.2.3.4"),
			dstPort:   443,
			want:      ActionBlock,
		},
		{
			name:          "domain wildcard match blocks",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
					Action:  ActionBlock,
				},
			},
			src:       netip.MustParseAddr("10.0.0.1"),
			dstDomain: mustDomain(t, "phishing.evil.com"),
			dstAddr:   netip.MustParseAddr("1.2.3.4"),
			dstPort:   443,
			want:      ActionBlock,
		},
		{
			name:          "domain wildcard does not match base",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
					Action:  ActionBlock,
				},
			},
			src:       netip.MustParseAddr("10.0.0.1"),
			dstDomain: mustDomain(t, "evil.com"),
			dstAddr:   netip.MustParseAddr("1.2.3.4"),
			dstPort:   443,
			want:      ActionAllow,
		},
		{
			name:          "case insensitive domain match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Domains: []domain.Domain{mustDomain(t, "Example.COM")},
					Action:  ActionBlock,
				},
			},
			src:       netip.MustParseAddr("10.0.0.1"),
			dstDomain: mustDomain(t, "EXAMPLE.com"),
			dstAddr:   netip.MustParseAddr("1.2.3.4"),
			dstPort:   443,
			want:      ActionBlock,
		},
		{
			name:          "source CIDR match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
					Action:  ActionInspect,
				},
			},
			src:     netip.MustParseAddr("192.168.1.50"),
			dstAddr: netip.MustParseAddr("1.2.3.4"),
			dstPort: 443,
			want:    ActionInspect,
		},
		{
			name:          "source CIDR no match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
					Action:  ActionBlock,
				},
			},
			src:     netip.MustParseAddr("10.0.0.5"),
			dstAddr: netip.MustParseAddr("1.2.3.4"),
			dstPort: 443,
			want:    ActionAllow,
		},
		{
			name:          "destination network match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:       id.RuleID("r1"),
					Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
					Action:   ActionInspect,
				},
			},
			src:     netip.MustParseAddr("192.168.1.1"),
			dstAddr: netip.MustParseAddr("10.50.0.1"),
			dstPort: 80,
			want:    ActionInspect,
		},
		{
			name:          "destination network no match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:       id.RuleID("r1"),
					Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
					Action:   ActionBlock,
				},
			},
			src:     netip.MustParseAddr("192.168.1.1"),
			dstAddr: netip.MustParseAddr("1.2.3.4"),
			dstPort: 80,
			want:    ActionAllow,
		},
		{
			name:          "port match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:     id.RuleID("r1"),
					Ports:  []uint16{443, 8443},
					Action: ActionInspect,
				},
			},
			src:     netip.MustParseAddr("10.0.0.1"),
			dstAddr: netip.MustParseAddr("1.2.3.4"),
			dstPort: 443,
			want:    ActionInspect,
		},
		{
			name:          "port no match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:     id.RuleID("r1"),
					Ports:  []uint16{443, 8443},
					Action: ActionBlock,
				},
			},
			src:     netip.MustParseAddr("10.0.0.1"),
			dstAddr: netip.MustParseAddr("1.2.3.4"),
			dstPort: 22,
			want:    ActionAllow,
		},
		{
			// The rules are deliberately supplied out of priority order.
			// UpdateRules must sort them by ascending Priority, so the
			// priority-1 allow rule is evaluated (and wins) before the
			// catch-all priority-10 inspect rule. If they were given already
			// sorted, this case would only test "first match wins".
			name:          "priority ordering first match wins",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:       id.RuleID("inspect-all"),
					Action:   ActionInspect,
					Priority: 10,
				},
				{
					ID:       id.RuleID("allow-internal"),
					Domains:  []domain.Domain{mustDomain(t, "*.internal.corp")},
					Action:   ActionAllow,
					Priority: 1,
				},
			},
			src:       netip.MustParseAddr("10.0.0.1"),
			dstDomain: mustDomain(t, "api.internal.corp"),
			dstAddr:   netip.MustParseAddr("10.1.0.5"),
			dstPort:   443,
			want:      ActionAllow,
		},
		{
			name:          "all fields must match (AND logic)",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
					Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
					Ports:   []uint16{443},
					Action:  ActionBlock,
				},
			},
			// Source matches, domain matches, but port doesn't
			src:       netip.MustParseAddr("192.168.1.10"),
			dstDomain: mustDomain(t, "phish.evil.com"),
			dstAddr:   netip.MustParseAddr("1.2.3.4"),
			dstPort:   8080,
			want:      ActionAllow,
		},
		{
			name:          "empty domain with domain rule does not match",
			defaultAction: ActionAllow,
			rules: []Rule{
				{
					ID:      id.RuleID("r1"),
					Domains: []domain.Domain{mustDomain(t, "example.com")},
					Action:  ActionBlock,
				},
			},
			src:       netip.MustParseAddr("10.0.0.1"),
			dstDomain: "", // raw IP connection, no SNI
			dstAddr:   netip.MustParseAddr("1.2.3.4"),
			dstPort:   443,
			want:      ActionAllow,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine := NewRuleEngine(testLogger(), tt.defaultAction)
			engine.UpdateRules(tt.rules, tt.defaultAction)
			got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
			assert.Equal(t, tt.want, got)
		})
	}
}
func TestRuleEngine_ProtocolMatching(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-websocket",
Protocols: []ProtoType{ProtoWebSocket},
Action: ActionBlock,
Priority: 1,
},
{
ID: "inspect-h2",
Protocols: []ProtoType{ProtoH2},
Action: ActionInspect,
Priority: 2,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
// WebSocket: blocked by rule
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
// HTTP/2: inspected by rule
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
// Plain HTTP: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
// HTTPS: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
// QUIC/H3: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
// Empty protocol (unknown): no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
}
// TestRuleEngine_EmptyProtocolsMatchAll verifies that a rule with no
// Protocols restriction applies to every protocol. This includes QUIC/H3
// and the empty (unknown) protocol, so the block rule must win for each one.
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
	engine := NewRuleEngine(testLogger(), ActionAllow)
	engine.UpdateRules([]Rule{
		{
			ID:     "block-all-protos",
			Action: ActionBlock,
			// No Protocols field = match all protocols
			Priority: 1,
		},
	}, ActionAllow)

	src := netip.MustParseAddr("10.0.0.1")
	dst := netip.MustParseAddr("1.2.3.4")

	// ProtoH3 is a protocol the engine distinguishes, so an unrestricted
	// rule must cover it too. It was previously missing from this list.
	protocols := []ProtoType{ProtoHTTP, ProtoHTTPS, ProtoWebSocket, ProtoH2, ProtoH3, ""}
	for _, proto := range protocols {
		assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, proto, ""), "protocol %q", proto)
	}
}
// TestRuleEngine_UpdateRulesSortsByPriority verifies that UpdateRules stores
// rules in ascending Priority order, whatever order they are supplied in.
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
	engine := NewRuleEngine(testLogger(), ActionAllow)
	engine.UpdateRules([]Rule{
		{ID: "c", Priority: 30, Action: ActionBlock},
		{ID: "a", Priority: 10, Action: ActionInspect},
		{ID: "b", Priority: 20, Action: ActionAllow},
	}, ActionAllow)

	// Inspect the engine's internal rule slice while holding its read lock.
	engine.mu.RLock()
	defer engine.mu.RUnlock()

	require.Len(t, engine.rules, 3)
	for i, wantID := range []id.RuleID{"a", "b", "c"} {
		assert.Equal(t, wantID, engine.rules[i].ID)
	}
}