Code cleaning in firewall package

This commit is contained in:
Zoltán Papp
2025-01-25 20:29:06 +01:00
parent 8185614362
commit efa8c17d27
42 changed files with 889 additions and 868 deletions

View File

@@ -18,7 +18,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -167,11 +167,11 @@ func (r *router) createContainers() error {
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
@@ -200,7 +200,7 @@ func (r *router) AddRouteFiltering(
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
@@ -219,7 +219,7 @@ func (r *router) AddRouteFiltering(
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
if action == types.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
@@ -248,7 +248,7 @@ func (r *router) AddRouteFiltering(
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
setName := types.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
@@ -270,7 +270,7 @@ func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) DeleteRouteRule(rule types.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -307,7 +307,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
sources = mergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
@@ -403,7 +403,7 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
func (r *router) AddNatRule(pair types.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -420,7 +420,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
@@ -433,7 +433,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
func (r *router) addNatRule(pair types.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -494,7 +494,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
@@ -584,7 +584,7 @@ func (r *router) addPostroutingRules() error {
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -597,7 +597,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -615,8 +615,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
@@ -651,7 +651,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
@@ -829,7 +829,7 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
}
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
func (r *router) RemoveNatRule(pair types.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -838,7 +838,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
@@ -854,8 +854,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
func (r *router) removeNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
@@ -931,7 +931,7 @@ func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
func applyPort(port *types.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
@@ -987,3 +987,27 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs
}
func mergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}