Refactor protocol handling for firewall rules, add engine tests

This commit is contained in:
Givi Khojanashvili
2023-03-21 17:56:47 +04:00
parent 0abd05d51e
commit 64ad771099
10 changed files with 263 additions and 112 deletions

View File

@@ -44,6 +44,7 @@ type Manager interface {
// rule ID as comment for the rule
AddFiltering(
ip net.IP,
proto Protocol,
port *Port,
direction Direction,
action Action,

View File

@@ -54,6 +54,7 @@ func Create() (*Manager, error) {
// If comment is empty rule ID is used as comment
func (m *Manager) AddFiltering(
ip net.IP,
protocol fw.Protocol,
port *fw.Port,
direction fw.Direction,
action fw.Action,
@@ -74,27 +75,27 @@ func (m *Manager) AddFiltering(
}
}
var portValue, protocolValue string
var portValue string
if port != nil && port.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
portValue = strconv.Itoa(port.Values[0])
switch port.Proto {
case fw.PortProtocolTCP:
protocolValue = "tcp"
case fw.PortProtocolUDP:
protocolValue = "udp"
default:
return nil, fmt.Errorf("unsupported protocol: %s", port.Proto)
}
}
ruleID := uuid.New().String()
if comment == "" {
comment = ruleID
}
specs := m.filterRuleSpecs(
"filter", ChainFilterName, ip, protocolValue,
portValue, direction, action, comment)
"filter",
ChainFilterName,
ip,
string(protocol),
portValue,
direction,
action,
comment,
)
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
return nil, err
}

View File

@@ -22,8 +22,8 @@ func TestNewManager(t *testing.T) {
var rule1 fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
if err != nil {
t.Errorf("failed to add rule: %v", err)
}
@@ -35,11 +35,10 @@ func TestNewManager(t *testing.T) {
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Proto: fw.PortProtocolTCP,
Values: []int{8043: 8046},
}
rule2, err = manager.AddFiltering(
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
ip, "tcp", port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
if err != nil {
t.Errorf("failed to add rule: %v", err)
}
@@ -66,8 +65,8 @@ func TestNewManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
if err != nil {
t.Errorf("failed to add rule: %v", err)
}

View File

@@ -1,14 +1,17 @@
package firewall
// PortProtocol is the protocol of the port
type PortProtocol string
// Protocol is the protocol of the port
type Protocol string
const (
// PortProtocolTCP is the TCP protocol
PortProtocolTCP PortProtocol = "tcp"
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// PortProtocolUDP is the UDP protocol
PortProtocolUDP PortProtocol = "udp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
)
// Port of the address for firewall rule
@@ -18,7 +21,4 @@ type Port struct {
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []int
// Proto is the protocol of the port
Proto PortProtocol
}

View File

@@ -226,6 +226,10 @@ func (e *Engine) Start() error {
e.firewallManager, err = buildFirewallManager()
if err != nil {
log.Errorf("failed to create firewall manager, ACL policy will not work: %s", err.Error())
} else {
if err := e.firewallManager.Reset(); err != nil {
log.Tracef("failed to reset firewall manager on the start: %v", err.Error())
}
}
e.firewallRules = make(map[string]firewall.Rule)
@@ -632,9 +636,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
if err := e.applyFirewallRules(networkMap.FirewallRules); err != nil {
log.Errorf("failed apply firewall rules, err: %v", err)
}
e.applyFirewallRules(networkMap.FirewallRules)
e.networkSerial = serial
return nil
}
@@ -1049,68 +1051,58 @@ func (e *Engine) close() {
}
// applyFirewallRules to the local firewall manager processed by ACL policy.
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) {
if e.firewallManager == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return nil
return
}
for ruleID, rule := range e.firewallRules {
if err := e.firewallManager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
delete(e.firewallRules, ruleID)
}
newRules := make(map[string]struct{}, 0)
for _, r := range rules {
rule := e.protoRuleToFirewallRule(r)
if rule == nil {
continue
}
newRules[rule.GetRuleID()] = struct{}{}
}
for ruleID := range e.firewallRules {
if _, ok := newRules[ruleID]; ok {
continue
}
if rule, ok := e.firewallRules[ruleID]; ok {
if err := e.firewallManager.DeleteRule(rule); err != nil {
log.Debugf("failed to delete firewall rule: %v", err)
continue
}
delete(e.firewallRules, ruleID)
if rule := e.protoRuleToFirewallRule(r); rule == nil {
log.Errorf("failed to apply firewall rule: %v", r)
}
}
return nil
}
func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
log.Debug("invalid IP address, skipping firewall rule")
log.Error("invalid IP address, skipping firewall rule")
return nil
}
var port *firewall.Port
if r.Port != "" {
split := strings.Split(r.Port, "/")
value, err := strconv.Atoi(split[0])
if err != nil {
log.Debug("invalid port, skipping firewall rule")
return nil
}
// port can be empty, so ignore conversion error
value, _ := strconv.Atoi(split[0])
port = &firewall.Port{}
port.Values = []int{value}
// get protocol from the port suffix if it exists
if len(split) > 1 {
switch split[1] {
case "tcp":
port.Proto = firewall.PortProtocolTCP
case "udp":
port.Proto = firewall.PortProtocolUDP
default:
log.Debug("invalid protocol, skipping firewall rule")
return nil
}
if value != 0 {
port.Values = []int{value}
}
}
var protocol firewall.Protocol
switch r.Protocol {
case "tcp":
protocol = firewall.ProtocolTCP
case "udp":
protocol = firewall.ProtocolUDP
case "icmp":
protocol = firewall.ProtocolICMP
default:
log.Errorf("invalid protocol, skipping firewall rule: %q", r.Protocol)
return nil
}
var direction firewall.Direction
switch r.Direction {
case "src":
@@ -1118,7 +1110,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
case "dst":
direction = firewall.DirectionDst
default:
log.Debug("invalid direction, skipping firewall rule")
log.Error("invalid direction, skipping firewall rule")
return nil
}
@@ -1129,13 +1121,13 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
case "drop":
action = firewall.ActionDrop
default:
log.Debug("invalid action, skipping firewall rule")
log.Error("invalid action, skipping firewall rule")
return nil
}
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
rule, err := e.firewallManager.AddFiltering(ip, protocol, port, direction, action, "")
if err != nil {
log.Debugf("failed to add firewall rule: %v", err)
log.Errorf("failed to add firewall rule: %v", err)
return nil
}
e.firewallRules[rule.GetRuleID()] = rule

View File

@@ -20,6 +20,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -28,6 +29,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
mgmt "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
@@ -938,6 +940,134 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
}
}
func TestEngine_firewallManager(t *testing.T) {
// TODO: enable when other platform will be added
if runtime.GOOS != "linux" {
t.Skipf("firewall manager not supported in the: %s", runtime.GOOS)
return
}
ctx, cancel := context.WithTimeout(CtxInitState(context.Background()), 10*time.Second)
defer cancel()
dir := t.TempDir()
err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
if err != nil {
t.Errorf("copy temporary store file: %v", err)
}
sigServer, signalAddr, err := startSignal()
if err != nil {
t.Errorf("start signal server: %v", err)
return
}
defer sigServer.GracefulStop()
mgmtServer, mgmtAddr, err := startManagement(dir)
if err != nil {
t.Errorf("start management server: %v", err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
engine, err := createEngine(ctx, cancel, setupKey, 0, mgmtAddr, signalAddr)
if err != nil {
t.Errorf("create engine: %v", err)
return
}
engine.dnsServer = &dns.MockServer{}
if err := engine.Start(); err != nil {
t.Logf("start engine: %v", err)
return
}
defer func() {
if err := engine.mgmClient.Close(); err != nil {
t.Logf("close management client: %v", err)
}
if err := engine.Stop(); err != nil {
t.Logf("stop engine: %v", err)
}
}()
// wait 2 seconds until all management and signal processing will be finished
time.Sleep(2 * time.Second)
if engine.firewallManager == nil {
t.Errorf("firewall manager is nil")
return
}
fwRules := []*mgmProto.FirewallRule{
{
PeerID: "test",
PeerIP: "10.93.0.1",
Direction: "dst",
Action: "accept",
Protocol: "tcp",
Port: "80",
},
{
PeerID: "test2",
PeerIP: "10.93.0.2",
Direction: "dst",
Action: "drop",
Protocol: "udp",
Port: "53",
},
}
// we receive one rule from the management so for testing purposes ignore it
engine.firewallRules = make(map[string]firewall.Rule)
t.Run("apply firewall rules", func(t *testing.T) {
engine.applyFirewallRules(fwRules)
if len(engine.firewallRules) != 2 {
t.Errorf("firewall rules not applied: %v", engine.firewallRules)
return
}
})
t.Run("add extra rules", func(t *testing.T) {
// remove first rule
fwRules = fwRules[1:]
fwRules = append(fwRules, &mgmtProto.FirewallRule{
PeerID: "test3",
PeerIP: "10.93.0.3",
Direction: "src",
Action: "drop",
Protocol: "icmp",
})
existedRulesID := map[string]struct{}{}
for id := range engine.firewallRules {
existedRulesID[id] = struct{}{}
}
engine.applyFirewallRules(fwRules)
// we should have one old and one new rule in the existed rules
if len(engine.firewallRules) != 2 {
t.Errorf("firewall rules not applied")
return
}
// check that old rules was removed
for id := range existedRulesID {
if _, ok := engine.firewallRules[id]; ok {
t.Errorf("old rule was not removed")
return
}
}
})
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {