mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Refactor protocol handling for firewall rules, add engine tests
This commit is contained in:
@@ -44,6 +44,7 @@ type Manager interface {
|
||||
// rule ID as comment for the rule
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
port *Port,
|
||||
direction Direction,
|
||||
action Action,
|
||||
|
||||
@@ -54,6 +54,7 @@ func Create() (*Manager, error) {
|
||||
// If comment is empty rule ID is used as comment
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol fw.Protocol,
|
||||
port *fw.Port,
|
||||
direction fw.Direction,
|
||||
action fw.Action,
|
||||
@@ -74,27 +75,27 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
}
|
||||
|
||||
var portValue, protocolValue string
|
||||
var portValue string
|
||||
if port != nil && port.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
portValue = strconv.Itoa(port.Values[0])
|
||||
switch port.Proto {
|
||||
case fw.PortProtocolTCP:
|
||||
protocolValue = "tcp"
|
||||
case fw.PortProtocolUDP:
|
||||
protocolValue = "udp"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", port.Proto)
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs(
|
||||
"filter", ChainFilterName, ip, protocolValue,
|
||||
portValue, direction, action, comment)
|
||||
"filter",
|
||||
ChainFilterName,
|
||||
ip,
|
||||
string(protocol),
|
||||
portValue,
|
||||
direction,
|
||||
action,
|
||||
comment,
|
||||
)
|
||||
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ func TestNewManager(t *testing.T) {
|
||||
var rule1 fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
@@ -35,11 +35,10 @@ func TestNewManager(t *testing.T) {
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Proto: fw.PortProtocolTCP,
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
ip, "tcp", port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
@@ -66,8 +65,8 @@ func TestNewManager(t *testing.T) {
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package firewall
|
||||
|
||||
// PortProtocol is the protocol of the port
|
||||
type PortProtocol string
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
// PortProtocolTCP is the TCP protocol
|
||||
PortProtocolTCP PortProtocol = "tcp"
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
|
||||
// PortProtocolUDP is the UDP protocol
|
||||
PortProtocolUDP PortProtocol = "udp"
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
@@ -18,7 +21,4 @@ type Port struct {
|
||||
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
|
||||
// Proto is the protocol of the port
|
||||
Proto PortProtocol
|
||||
}
|
||||
|
||||
@@ -226,6 +226,10 @@ func (e *Engine) Start() error {
|
||||
e.firewallManager, err = buildFirewallManager()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create firewall manager, ACL policy will not work: %s", err.Error())
|
||||
} else {
|
||||
if err := e.firewallManager.Reset(); err != nil {
|
||||
log.Tracef("failed to reset firewall manager on the start: %v", err.Error())
|
||||
}
|
||||
}
|
||||
e.firewallRules = make(map[string]firewall.Rule)
|
||||
|
||||
@@ -632,9 +636,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
if err := e.applyFirewallRules(networkMap.FirewallRules); err != nil {
|
||||
log.Errorf("failed apply firewall rules, err: %v", err)
|
||||
}
|
||||
e.applyFirewallRules(networkMap.FirewallRules)
|
||||
e.networkSerial = serial
|
||||
return nil
|
||||
}
|
||||
@@ -1049,68 +1051,58 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
// applyFirewallRules to the local firewall manager processed by ACL policy.
|
||||
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
|
||||
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) {
|
||||
if e.firewallManager == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
for ruleID, rule := range e.firewallRules {
|
||||
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(e.firewallRules, ruleID)
|
||||
}
|
||||
|
||||
newRules := make(map[string]struct{}, 0)
|
||||
for _, r := range rules {
|
||||
rule := e.protoRuleToFirewallRule(r)
|
||||
if rule == nil {
|
||||
continue
|
||||
}
|
||||
newRules[rule.GetRuleID()] = struct{}{}
|
||||
}
|
||||
|
||||
for ruleID := range e.firewallRules {
|
||||
if _, ok := newRules[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
if rule, ok := e.firewallRules[ruleID]; ok {
|
||||
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||
log.Debugf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(e.firewallRules, ruleID)
|
||||
if rule := e.protoRuleToFirewallRule(r); rule == nil {
|
||||
log.Errorf("failed to apply firewall rule: %v", r)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
log.Debug("invalid IP address, skipping firewall rule")
|
||||
log.Error("invalid IP address, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
split := strings.Split(r.Port, "/")
|
||||
value, err := strconv.Atoi(split[0])
|
||||
if err != nil {
|
||||
log.Debug("invalid port, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
// port can be empty, so ignore conversion error
|
||||
value, _ := strconv.Atoi(split[0])
|
||||
port = &firewall.Port{}
|
||||
port.Values = []int{value}
|
||||
// get protocol from the port suffix if it exists
|
||||
if len(split) > 1 {
|
||||
switch split[1] {
|
||||
case "tcp":
|
||||
port.Proto = firewall.PortProtocolTCP
|
||||
case "udp":
|
||||
port.Proto = firewall.PortProtocolUDP
|
||||
default:
|
||||
log.Debug("invalid protocol, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
if value != 0 {
|
||||
port.Values = []int{value}
|
||||
}
|
||||
}
|
||||
|
||||
var protocol firewall.Protocol
|
||||
switch r.Protocol {
|
||||
case "tcp":
|
||||
protocol = firewall.ProtocolTCP
|
||||
case "udp":
|
||||
protocol = firewall.ProtocolUDP
|
||||
case "icmp":
|
||||
protocol = firewall.ProtocolICMP
|
||||
default:
|
||||
log.Errorf("invalid protocol, skipping firewall rule: %q", r.Protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
var direction firewall.Direction
|
||||
switch r.Direction {
|
||||
case "src":
|
||||
@@ -1118,7 +1110,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
case "dst":
|
||||
direction = firewall.DirectionDst
|
||||
default:
|
||||
log.Debug("invalid direction, skipping firewall rule")
|
||||
log.Error("invalid direction, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1129,13 +1121,13 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
case "drop":
|
||||
action = firewall.ActionDrop
|
||||
default:
|
||||
log.Debug("invalid action, skipping firewall rule")
|
||||
log.Error("invalid action, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
|
||||
rule, err := e.firewallManager.AddFiltering(ip, protocol, port, direction, action, "")
|
||||
if err != nil {
|
||||
log.Debugf("failed to add firewall rule: %v", err)
|
||||
log.Errorf("failed to add firewall rule: %v", err)
|
||||
return nil
|
||||
}
|
||||
e.firewallRules[rule.GetRuleID()] = rule
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmt "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -938,6 +940,134 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_firewallManager(t *testing.T) {
|
||||
// TODO: enable when other platform will be added
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("firewall manager not supported in the: %s", runtime.GOOS)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(CtxInitState(context.Background()), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Errorf("copy temporary store file: %v", err)
|
||||
}
|
||||
|
||||
sigServer, signalAddr, err := startSignal()
|
||||
if err != nil {
|
||||
t.Errorf("start signal server: %v", err)
|
||||
return
|
||||
}
|
||||
defer sigServer.GracefulStop()
|
||||
|
||||
mgmtServer, mgmtAddr, err := startManagement(dir)
|
||||
if err != nil {
|
||||
t.Errorf("start management server: %v", err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
engine, err := createEngine(ctx, cancel, setupKey, 0, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
t.Errorf("create engine: %v", err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
|
||||
if err := engine.Start(); err != nil {
|
||||
t.Logf("start engine: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := engine.mgmClient.Close(); err != nil {
|
||||
t.Logf("close management client: %v", err)
|
||||
}
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
t.Logf("stop engine: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait 2 seconds until all management and signal processing will be finished
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if engine.firewallManager == nil {
|
||||
t.Errorf("firewall manager is nil")
|
||||
return
|
||||
}
|
||||
|
||||
fwRules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerID: "test",
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerID: "test2",
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: "dst",
|
||||
Action: "drop",
|
||||
Protocol: "udp",
|
||||
Port: "53",
|
||||
},
|
||||
}
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
engine.firewallRules = make(map[string]firewall.Rule)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
engine.applyFirewallRules(fwRules)
|
||||
|
||||
if len(engine.firewallRules) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", engine.firewallRules)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
// remove first rule
|
||||
fwRules = fwRules[1:]
|
||||
fwRules = append(fwRules, &mgmtProto.FirewallRule{
|
||||
PeerID: "test3",
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: "src",
|
||||
Action: "drop",
|
||||
Protocol: "icmp",
|
||||
})
|
||||
|
||||
existedRulesID := map[string]struct{}{}
|
||||
for id := range engine.firewallRules {
|
||||
existedRulesID[id] = struct{}{}
|
||||
}
|
||||
|
||||
engine.applyFirewallRules(fwRules)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(engine.firewallRules) != 2 {
|
||||
t.Errorf("firewall rules not applied")
|
||||
return
|
||||
}
|
||||
|
||||
// check that old rules was removed
|
||||
for id := range existedRulesID {
|
||||
if _, ok := engine.firewallRules[id]; ok {
|
||||
t.Errorf("old rule was not removed")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user