mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Code cleaning in firewall package
This commit is contained in:
@@ -18,7 +18,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -167,11 +167,11 @@ func (r *router) createContainers() error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
) (types.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
@@ -200,7 +200,7 @@ func (r *router) AddRouteFiltering(
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
if proto != types.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
@@ -219,7 +219,7 @@ func (r *router) AddRouteFiltering(
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
|
||||
var verdict expr.VerdictKind
|
||||
if action == firewall.ActionAccept {
|
||||
if action == types.ActionAccept {
|
||||
verdict = expr.VerdictAccept
|
||||
} else {
|
||||
verdict = expr.VerdictDrop
|
||||
@@ -248,7 +248,7 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
setName := types.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
@@ -270,7 +270,7 @@ func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
func (r *router) DeleteRouteRule(rule types.Rule) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -307,7 +307,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
sources = mergeIPRanges(sources)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
@@ -403,7 +403,7 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
}
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) AddNatRule(pair types.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -420,7 +420,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -433,7 +433,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) addNatRule(pair types.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -494,7 +494,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
@@ -584,7 +584,7 @@ func (r *router) addPostroutingRules() error {
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -597,7 +597,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -615,8 +615,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
@@ -651,7 +651,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
@@ -829,7 +829,7 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
}
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) RemoveNatRule(pair types.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -838,7 +838,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -854,8 +854,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
func (r *router) removeNatRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
@@ -931,7 +931,7 @@ func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
func applyPort(port *types.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -987,3 +987,27 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func mergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
return prefixes
|
||||
}
|
||||
|
||||
merged := []netip.Prefix{prefixes[0]}
|
||||
for _, prefix := range prefixes[1:] {
|
||||
last := merged[len(merged)-1]
|
||||
if last.Contains(prefix.Addr()) {
|
||||
// If the current prefix is contained within the last merged prefix, skip it
|
||||
continue
|
||||
}
|
||||
if prefix.Contains(last.Addr()) {
|
||||
// If the current prefix contains the last merged prefix, replace it
|
||||
merged[len(merged)-1] = prefix
|
||||
} else {
|
||||
// Otherwise, add the current prefix to the merged list
|
||||
merged = append(merged, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user