Extend protocol and firewall manager to handle old management (#915)

* Extend protocol and firewall manager to handle old management

* Send correct empty firewall rules list when delete peer

* Add extra tests for firewall manager and uspfilter

* Work with inconsistent state

* Review note

* Update comment
This commit is contained in:
Givi Khojanashvili
2023-05-31 21:04:38 +04:00
committed by GitHub
parent 45a6263adc
commit 293499c3c0
13 changed files with 362 additions and 220 deletions

View File

@@ -236,11 +236,20 @@ func (m *Manager) filterRuleSpecs(
table string, ip net.IP, protocol string, sPort, dPort string,
direction fw.RuleDirection, action fw.Action, comment string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if s := ip.String(); s == "0.0.0.0" || s == "::" {
matchByIP = false
}
switch direction {
case fw.RuleDirectionIN:
specs = append(specs, "-s", ip.String())
if matchByIP {
specs = append(specs, "-s", ip.String())
}
case fw.RuleDirectionOUT:
specs = append(specs, "-d", ip.String())
if matchByIP {
specs = append(specs, "-d", ip.String())
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)

View File

@@ -139,38 +139,41 @@ func (m *Manager) AddFiltering(
})
}
// source address position
var adrLen, adrOffset uint32
if ip.To4() == nil {
adrLen = 16
adrOffset = 8
} else {
adrLen = 4
adrOffset = 12
// don't use IP matching if IP is ip 0.0.0.0
if s := ip.String(); s != "0.0.0.0" && s != "::" {
// source address position
var adrLen, adrOffset uint32
if ip.To4() == nil {
adrLen = 16
adrOffset = 8
} else {
adrLen = 4
adrOffset = 12
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
adrOffset += adrLen
}
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: adrOffset,
Len: adrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
},
)
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
adrOffset += adrLen
}
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: adrOffset,
Len: adrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
},
)
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{

View File

@@ -13,6 +13,7 @@ type Rule struct {
id string
ip net.IP
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
direction fw.RuleDirection
sPort uint16

View File

@@ -87,6 +87,7 @@ func (m *Manager) AddFiltering(
id: uuid.New().String(),
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
direction: direction,
drop: action == fw.ActionDrop,
comment: comment,
@@ -96,6 +97,10 @@ func (m *Manager) AddFiltering(
r.ip = ipNormalized
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
r.matchByIP = false
}
if sPort != nil && len(sPort.Values) == 1 {
r.sPort = uint16(sPort.Values[0])
}
@@ -223,25 +228,27 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
// check if IP address match by IP
for _, rule := range rules {
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) {
continue
if rule.matchByIP {
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
if !d.ip4.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip4.DstIP.Equal(rule.ip) {
continue
}
}
} else {
if !d.ip4.DstIP.Equal(rule.ip) {
continue
}
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip6.DstIP.Equal(rule.ip) {
continue
case layers.LayerTypeIPv6:
if isIncomingPacket {
if !d.ip6.SrcIP.Equal(rule.ip) {
continue
}
} else {
if !d.ip6.DstIP.Equal(rule.ip) {
continue
}
}
}
}

View File

@@ -6,6 +6,8 @@ import (
"testing"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall"
@@ -171,6 +173,72 @@ func TestManagerReset(t *testing.T) {
}
}
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
}
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
direction := fw.RuleDirectionOUT
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: net.ParseIP("100.10.0.1"),
DstIP: net.ParseIP("100.10.0.100"),
Protocol: layers.IPProtocolUDP,
}
udp := &layers.UDP{
SrcPort: 51334,
DstPort: 53,
}
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
t.Errorf("failed to set network layer for checksum: %v", err)
return
}
payload := gopacket.Payload([]byte("test"))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload); err != nil {
t.Errorf("failed to serialize packet: %v", err)
return
}
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
t.Errorf("expected packet to be accepted")
return
}
if err = m.Reset(); err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
}
}
func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {