[management, client] Add access control support to network routes (#2100)

This commit is contained in:
Bethuel Mmbaga
2024-10-02 14:41:00 +03:00
committed by GitHub
parent a3a479429e
commit ff7863785f
48 changed files with 4683 additions and 2444 deletions

View File

@@ -0,0 +1,25 @@
package id
import (
"fmt"
"net/netip"
"github.com/netbirdio/netbird/client/firewall/manager"
)
type RuleID string
func (r RuleID) GetRuleID() string {
return string(r)
}
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination netip.Prefix,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,
action manager.Action,
) RuleID {
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -23,16 +25,18 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
rulesPairs: make(map[string][]firewall.Rule),
firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
routeRules: make(map[id.RuleID]struct{}),
}
}
@@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
start := time.Now()
defer func() {
total := 0
for _, pairs := range d.rulesPairs {
for _, pairs := range d.peerRulesPairs {
total += len(pairs)
}
log.Infof(
@@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return
}
defer func() {
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
d.applyPeerACLs(networkMap)
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := (networkMap.PeerConfig != nil &&
enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled)
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
@@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort),
})
}
@@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
rules = append(rules,
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
)
}
newRulePairs := make(map[string][]firewall.Rule)
newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
@@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
break
}
if len(rules) > 0 {
d.rulesPairs[pairID] = rulePair
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
}
for pairID, rules := range d.rulesPairs {
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
}
delete(d.rulesPairs, pairID)
delete(d.peerRulesPairs, pairID)
}
}
d.rulesPairs = newRulePairs
d.peerRulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
var newRouteRules = make(map[id.RuleID]struct{})
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
if err != nil {
return fmt.Errorf("apply route ACL: %w", err)
}
newRouteRules[id] = struct{}{}
}
for id := range d.routeRules {
if _, ok := newRouteRules[id]; !ok {
if err := d.firewall.DeleteRouteRule(id); err != nil {
log.Errorf("failed to delete route firewall rule: %v", err)
continue
}
delete(d.routeRules, id)
}
}
d.routeRules = newRouteRules
return nil
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", fmt.Errorf("source ranges is empty")
}
var sources []netip.Prefix
for _, sourceRange := range rule.SourceRanges {
source, err := netip.ParsePrefix(sourceRange)
if err != nil {
return "", fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, source)
}
var destination netip.Prefix
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
if err != nil {
return "", fmt.Errorf("invalid protocol: %w", err)
}
action, err := convertFirewallAction(rule.Action)
if err != nil {
return "", fmt.Errorf("invalid action: %w", err)
}
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
if err != nil {
return "", fmt.Errorf("add route rule: %w", err)
}
return id.RuleID(addedRule.GetRuleID()), nil
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
) (id.RuleID, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
}
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.FirewallRule_IN:
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT:
case mgmProto.RuleDirection_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddFiltering(
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules(
return rules, nil
}
rule, err = d.firewall.AddFiltering(
rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddFiltering(
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules(
return rules, nil
}
rule, err = d.firewall.AddFiltering(
rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules(
return append(rules, rule...), nil
}
// getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getRuleID(
// getPeerRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
) string {
) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
@@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID(
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
@@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules(
}
}
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
type protoMatch map[mgmProto.RuleProtocol]map[string]int
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
@@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules(
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
protocols[r.Protocol] = map[string]int{}
return
@@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules(
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.FirewallRule_IN {
if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
@@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules(
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.FirewallRuleProtocol{
mgmProto.FirewallRule_ALL,
mgmProto.FirewallRule_ICMP,
mgmProto.FirewallRule_TCP,
mgmProto.FirewallRule_UDP,
protocolOrders := []mgmProto.RuleProtocol{
mgmProto.RuleProtocol_ALL,
mgmProto.RuleProtocol_ICMP,
mgmProto.RuleProtocol_TCP,
mgmProto.RuleProtocol_UDP,
}
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if :
@@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules(
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.FirewallRule_ACCEPT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.FirewallRule_ALL {
if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
@@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules(
}
}
squash(in, mgmProto.FirewallRule_IN)
squash(out, mgmProto.FirewallRule_OUT)
squash(in, mgmProto.RuleDirection_IN)
squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
@@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.FirewallRule_TCP:
case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil
case mgmProto.FirewallRule_UDP:
case mgmProto.RuleProtocol_UDP:
return firewall.ProtocolUDP, nil
case mgmProto.FirewallRule_ICMP:
case mgmProto.RuleProtocol_ICMP:
return firewall.ProtocolICMP, nil
case mgmProto.FirewallRule_ALL:
case mgmProto.RuleProtocol_ALL:
return firewall.ProtocolALL, nil
default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
@@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
switch action {
case mgmProto.FirewallRule_ACCEPT:
case mgmProto.RuleAction_ACCEPT:
return firewall.ActionAccept, nil
case mgmProto.FirewallRule_DROP:
case mgmProto.RuleAction_DROP:
return firewall.ActionDrop, nil
default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
if portInfo == nil {
return nil
}
if portInfo.GetPort() != 0 {
return &firewall.Port{
Values: []int{int(portInfo.GetPort())},
}
}
if portInfo.GetRange() != nil {
return &firewall.Port{
IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
}
}
return nil
}
func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.FirewallRule_UDP,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
@@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
if len(acl.rulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
return
}
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
for id := range acl.rulesPairs {
existedPairs[id] = struct{}{}
for id := range acl.peerRulesPairs {
existedPairs[id.GetRuleID()] = struct{}{}
}
// remove first rule
@@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules,
&mgmProto.FirewallRule{
PeerIP: "10.93.0.3",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.FirewallRule_ICMP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_ICMP,
},
)
acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules
if len(acl.rulesPairs) != 2 {
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied")
return
}
// check that old rule was removed
previousCount := 0
for id := range acl.rulesPairs {
if _, ok := existedPairs[id]; ok {
for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id.GetRuleID()]; ok {
previousCount++
}
}
@@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
if len(acl.rulesPairs) != 2 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
if len(acl.peerRulesPairs) != 2 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
}
})
@@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
@@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.FirewallRule_IN:
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.FirewallRule_ALL:
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.FirewallRule_ACCEPT:
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
case r.Direction != mgmProto.FirewallRule_OUT:
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.FirewallRule_ALL:
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.FirewallRule_ACCEPT:
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_UDP,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
@@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_UDP,
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
@@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap)
if len(acl.rulesPairs) != 4 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
if len(acl.peerRulesPairs) != 4 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return
}
}

View File

@@ -704,6 +704,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
// Apply ACLs in the beginning to avoid security leaks
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
@@ -770,10 +775,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.

View File

@@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}

View File

@@ -87,10 +87,10 @@ func NewManager(
}
dm.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ any) (any, error) {
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ any) error {
func(prefix netip.Prefix, _ struct{}) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)

View File

@@ -3,7 +3,8 @@ package refcounter
import (
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,118 +13,153 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
)
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
const logLevel = log.TraceLevel
// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key.
var ErrIgnore = errors.New("ignore")
// Ref holds the reference count and associated data for a key.
type Ref[O any] struct {
Count int
Out O
}
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
// AddFunc is the function type for adding a new key.
// Key is the type of the key (e.g., netip.Prefix).
type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error)
type Counter[I, O any] struct {
// refCountMap keeps track of the reference Ref for prefixes
refCountMap map[netip.Prefix]Ref[O]
// RemoveFunc is the function type for removing a key.
type RemoveFunc[Key, O any] func(key Key, out O) error
// Counter is a generic reference counter for managing keys and their associated data.
// Key: The type of the key (e.g., netip.Prefix, string).
//
// I: The input type for the AddFunc. It is the input type for additional data needed
// when adding a key, it is passed as the second argument to AddFunc.
//
// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc.
// It is stored and passed to RemoveFunc when the reference count reaches 0.
//
// The types can be aliased to a specific type using the following syntax:
//
// type RouteRefCounter = Counter[netip.Prefix, any, any]
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
// idMap keeps track of the prefixes associated with an ID for removal
idMap map[string][]netip.Prefix
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
idMu sync.Mutex
add AddFunc[I, O]
remove RemoveFunc[I, O]
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
// New creates a new Counter instance
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
return &Counter[I, O]{
refCountMap: map[netip.Prefix]Ref[O]{},
idMap: map[string][]netip.Prefix{},
// New creates a new Counter instance.
// Usage example:
//
// counter := New[netip.Prefix, string, string](
// func(key netip.Prefix, in string) (out string, err error) { ... },
// func(key netip.Prefix, out string) error { ... },`
// )
func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] {
return &Counter[Key, I, O]{
refCountMap: map[Key]Ref[O]{},
idMap: map[string][]Key{},
add: add,
remove: remove,
}
}
// Increment increments the reference count for the given prefix.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[prefix]
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
ref, ok := rm.refCountMap[key]
return ref, ok
}
// Call AddFunc only if it's a new prefix
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
// Call AddFunc only if it's a new key
if ref.Count == 0 {
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
out, err := rm.add(prefix, in)
logCallerF("Calling add for key %v", key)
out, err := rm.add(key, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
return ref, fmt.Errorf("failed to add for key %v: %w", key, err)
}
ref.Out = out
}
ref.Count++
rm.refCountMap[prefix] = ref
rm.refCountMap[key] = ref
return ref, nil
}
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.Increment(prefix, in)
ref, err := rm.Increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
rm.idMap[id] = append(rm.idMap[id], prefix)
rm.idMap[id] = append(rm.idMap[id], key)
return ref, nil
}
// Decrement decrements the reference count for the given prefix.
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[prefix]
ref, ok := rm.refCountMap[key]
if !ok {
log.Tracef("No reference found for prefix %s", prefix)
logCallerF("No reference found for key %v", key)
return ref, nil
}
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out)
if ref.Count == 1 {
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
if err := rm.remove(prefix, ref.Out); err != nil {
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
logCallerF("Calling remove for key %v", key)
if err := rm.remove(key, ref.Out); err != nil {
return ref, fmt.Errorf("remove for key %v: %w", key, err)
}
delete(rm.refCountMap, prefix)
delete(rm.refCountMap, key)
} else {
ref.Count--
rm.refCountMap[prefix] = ref
rm.refCountMap[key] = ref
}
return ref, nil
}
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) DecrementWithID(id string) error {
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, prefix := range rm.idMap[id] {
if _, err := rm.Decrement(prefix); err != nil {
for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error {
return nberrors.FormatErrorOrNil(merr)
}
// Flush removes all references and calls RemoveFunc for each prefix.
func (rm *Counter[I, O]) Flush() error {
// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for prefix := range rm.refCountMap {
log.Tracef("Removing for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.remove(prefix, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
for key := range rm.refCountMap {
logCallerF("Calling remove for key %v", key)
ref := rm.refCountMap[key]
if err := rm.remove(key, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err))
}
}
rm.refCountMap = map[netip.Prefix]Ref[O]{}
rm.idMap = map[string][]netip.Prefix{}
clear(rm.refCountMap)
clear(rm.idMap)
return nberrors.FormatErrorOrNil(merr)
}
// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
clear(rm.refCountMap)
clear(rm.idMap)
}
func getCallerInfo(depth int, maxDepth int) (string, bool) {
if depth >= maxDepth {
return "", false
}
pc, _, _, ok := runtime.Caller(depth)
if !ok {
return "", false
}
if details := runtime.FuncForPC(pc); details != nil {
name := details.Name()
lastDotIndex := strings.LastIndex(name, "/")
if lastDotIndex != -1 {
name = name[lastDotIndex+1:]
}
if strings.HasPrefix(name, "refcounter.") {
// +2 to account for recursion
return getCallerInfo(depth+2, maxDepth)
}
return name, true
}
return "", false
}
// logCaller logs a message with the package name and method of the function that called the current function.
func logCallerF(format string, args ...interface{}) {
if log.GetLevel() < logLevel {
return
}
if callerName, ok := getCallerInfo(3, 18); ok {
format = fmt.Sprintf("[%s] %s", callerName, format)
}
log.StandardLogger().Logf(logLevel, format, args...)
}

View File

@@ -1,7 +1,9 @@
package refcounter
import "net/netip"
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
type RouteRefCounter = Counter[any, any]
type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
type AllowedIPsRefCounter = Counter[string, string]
type AllowedIPsRefCounter = Counter[netip.Prefix, string, string]

View File

@@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.RemoveRoutingRules(routerPair)
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
@@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.InsertRoutingRules(routerPair)
err = m.firewall.AddNatRule(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
@@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() {
continue
}
err = m.firewall.RemoveRoutingRules(routerPair)
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
@@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
destination := route.Network.Masked().String()
destination := route.Network.Masked()
if route.IsDynamic() {
// TODO: add ipv6
destination = "0.0.0.0/0"
// TODO: add ipv6 additionally
destination = getDefaultPrefix(destination)
}
return firewall.RouterPair{
ID: string(route.ID),
Source: source.String(),
ID: route.ID,
Source: source,
Destination: destination,
Masquerade: route.Masquerade,
}, nil

View File

@@ -30,7 +30,7 @@ func (r *Route) String() string {
}
func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
_, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
return err
}

View File

@@ -15,7 +15,7 @@ type Nexthop struct {
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter

View File

@@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
}
refCounter := refcounter.New(
func(prefix netip.Prefix, _ any) (Nexthop, error) {
func(prefix netip.Prefix, _ struct{}) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
@@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}