mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
Extend protocol and firewall manager to handle old management (#915)
* Extend protocol and firewall manager to handle old management * Send correct empty firewall rules list when delete peer * Add extra tests for firewall manager and uspfilter * Work with inconsistent state * Review note * Update comment
This commit is contained in:
committed by
GitHub
parent
45a6263adc
commit
293499c3c0
@@ -236,11 +236,20 @@ func (m *Manager) filterRuleSpecs(
|
||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||
direction fw.RuleDirection, action fw.Action, comment string,
|
||||
) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
matchByIP = false
|
||||
}
|
||||
switch direction {
|
||||
case fw.RuleDirectionIN:
|
||||
specs = append(specs, "-s", ip.String())
|
||||
if matchByIP {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
case fw.RuleDirectionOUT:
|
||||
specs = append(specs, "-d", ip.String())
|
||||
if matchByIP {
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
|
||||
@@ -139,38 +139,41 @@ func (m *Manager) AddFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
// source address position
|
||||
var adrLen, adrOffset uint32
|
||||
if ip.To4() == nil {
|
||||
adrLen = 16
|
||||
adrOffset = 8
|
||||
} else {
|
||||
adrLen = 4
|
||||
adrOffset = 12
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
||||
// source address position
|
||||
var adrLen, adrOffset uint32
|
||||
if ip.To4() == nil {
|
||||
adrLen = 16
|
||||
adrOffset = 8
|
||||
} else {
|
||||
adrLen = 4
|
||||
adrOffset = 12
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
adrOffset += adrLen
|
||||
}
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: adrOffset,
|
||||
Len: adrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
adrOffset += adrLen
|
||||
}
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: adrOffset,
|
||||
Len: adrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
)
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
|
||||
@@ -13,6 +13,7 @@ type Rule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
direction fw.RuleDirection
|
||||
sPort uint16
|
||||
|
||||
@@ -87,6 +87,7 @@ func (m *Manager) AddFiltering(
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
direction: direction,
|
||||
drop: action == fw.ActionDrop,
|
||||
comment: comment,
|
||||
@@ -96,6 +97,10 @@ func (m *Manager) AddFiltering(
|
||||
r.ip = ipNormalized
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
r.matchByIP = false
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) == 1 {
|
||||
r.sPort = uint16(sPort.Values[0])
|
||||
}
|
||||
@@ -223,25 +228,27 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
|
||||
// check if IP address match by IP
|
||||
for _, rule := range rules {
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
if rule.matchByIP {
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -171,6 +173,72 @@ func TestManagerReset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: net.ParseIP("100.10.0.1"),
|
||||
DstIP: net.ParseIP("100.10.0.100"),
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udp := &layers.UDP{
|
||||
SrcPort: 51334,
|
||||
DstPort: 53,
|
||||
}
|
||||
|
||||
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
||||
return
|
||||
}
|
||||
payload := gopacket.Payload([]byte("test"))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
if err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload); err != nil {
|
||||
t.Errorf("failed to serialize packet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
|
||||
if err = m.Reset(); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
|
||||
@@ -22,7 +22,7 @@ type iFaceMapper interface {
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(rules []*mgmProto.FirewallRule)
|
||||
ApplyFiltering(rules []*mgmProto.FirewallRule, allowByDefault bool)
|
||||
Stop()
|
||||
}
|
||||
|
||||
@@ -34,7 +34,9 @@ type DefaultManager struct {
|
||||
}
|
||||
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
func (d *DefaultManager) ApplyFiltering(rules []*mgmProto.FirewallRule) {
|
||||
//
|
||||
// If allowByDefault is ture it appends allow ALL traffic rules to input and output chains.
|
||||
func (d *DefaultManager) ApplyFiltering(rules []*mgmProto.FirewallRule, allowByDefault bool) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
@@ -47,6 +49,22 @@ func (d *DefaultManager) ApplyFiltering(rules []*mgmProto.FirewallRule) {
|
||||
applyFailed bool
|
||||
newRulePairs = make(map[string][]firewall.Rule)
|
||||
)
|
||||
if allowByDefault {
|
||||
rules = append(rules,
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.FirewallRule_IN,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
&mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.FirewallRule_OUT,
|
||||
Action: mgmProto.FirewallRule_ACCEPT,
|
||||
Protocol: mgmProto.FirewallRule_ALL,
|
||||
},
|
||||
)
|
||||
}
|
||||
for _, r := range rules {
|
||||
rules, err := d.protoRuleToFirewallRule(r)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
@@ -11,12 +10,6 @@ import (
|
||||
)
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
// TODO: enable when other platform will be added
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("ACL manager not supported in the: %s", runtime.GOOS)
|
||||
return
|
||||
}
|
||||
|
||||
fwRules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
@@ -38,8 +31,9 @@ func TestDefaultManager(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(false)
|
||||
iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFiltering(gomock.Any())
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
@@ -50,7 +44,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
defer acl.Stop()
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(fwRules)
|
||||
acl.ApplyFiltering(fwRules, false)
|
||||
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
|
||||
@@ -73,7 +67,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
existedRulesID[id] = struct{}{}
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(fwRules)
|
||||
acl.ApplyFiltering(fwRules, false)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
@@ -89,4 +83,18 @@ func TestDefaultManager(t *testing.T) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handle default rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(nil, false)
|
||||
if len(acl.rulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if default not allowed, got: %v rules", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(nil, true)
|
||||
if len(acl.rulesPairs) != 2 {
|
||||
t.Errorf("two default rules should be set if default is allowed, got: %v rules", len(acl.rulesPairs))
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -637,7 +637,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap.FirewallRules)
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
// we have old version of management without rules handling, we should allow all traffic
|
||||
allowByDefault := len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty
|
||||
if allowByDefault {
|
||||
log.Warn("this peer is connected to a NetBird Management service with an older version. Allowing all traffic from connected peers")
|
||||
}
|
||||
e.acl.ApplyFiltering(networkMap.FirewallRules, allowByDefault)
|
||||
}
|
||||
e.networkSerial = serial
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user