Fix/acl for forward (#1305)

Fix ACL on routed traffic and code refactor
This commit is contained in:
Zoltan Papp
2023-12-08 10:48:21 +01:00
committed by GitHub
parent b03343bc4d
commit 006ba32086
50 changed files with 3720 additions and 3627 deletions

View File

@@ -11,42 +11,27 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
Stop()
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
manager firewall.Manager
firewall firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
manager: fm,
firewall: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
}
@@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
time.Since(start), total)
}()
if d.manager == nil {
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
defer func() {
if err := d.manager.Flush(); err != nil {
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
@@ -125,57 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
)
}
applyFailed := false
newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules {
selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
if ipset.name == "" {
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true
d.rollBack(newRulePairs)
break
}
newRulePairs[pairID] = rulePair
}
if applyFailed {
log.Error("failed to apply firewall rules, rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
continue
}
}
if len(rules) > 0 {
d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
return
}
for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
@@ -186,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.rulesPairs = newRulePairs
}
// Stop ACL controller and clear firewall state
func (d *DefaultManager) Stop() {
d.mutex.Lock()
defer d.mutex.Unlock()
if err := d.manager.Reset(); err != nil {
log.WithError(err).Error("reset firewall state")
}
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
@@ -205,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
protocol := convertToFirewallProtocol(r.Protocol)
if protocol == firewall.ProtocolUnknown {
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
action := convertFirewallAction(r.Action)
if action == firewall.ActionUnknown {
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
@@ -232,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
var rules []firewall.Rule
var err error
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
@@ -246,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, err
}
d.rulesPairs[ruleID] = rules
return ruleID, rules, nil
}
@@ -259,24 +210,24 @@ func (d *DefaultManager) addInRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule)
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.manager.AddFiltering(
rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule), nil
return append(rules, rule...), nil
}
func (d *DefaultManager) addOutRules(
@@ -288,24 +239,24 @@ func (d *DefaultManager) addOutRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule)
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.manager.AddFiltering(
rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule), nil
return append(rules, rule...), nil
}
// getRuleID() returns unique ID for the rule based on its parameters.
@@ -461,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP
return firewall.ProtocolTCP, nil
case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP
return firewall.ProtocolUDP, nil
case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP
return firewall.ProtocolICMP, nil
case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL
return firewall.ProtocolALL, nil
default:
return firewall.ProtocolUnknown
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
@@ -480,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
switch action {
case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept
return firewall.ActionAccept, nil
case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop
return firewall.ActionDrop, nil
default:
return firewall.ActionUnknown
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}