mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-26 10:39:56 +00:00
Compare commits
6 Commits
refactor/m
...
worktree-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b8a5c15cc | ||
|
|
ddc9f8199a | ||
|
|
6c6ad8d14f | ||
|
|
7461d4cef4 | ||
|
|
c761d0d1cd | ||
|
|
c46dee4e6b |
199
client/firewall/iptables/dnat_refcount_linux_test.go
Normal file
199
client/firewall/iptables/dnat_refcount_linux_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func iptRefcountIfaceV4() *iFaceMock {
|
||||
return &iFaceMock{
|
||||
NameFunc: func() string { return "wt-refcount" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func iptRefcountIfaceDual() *iFaceMock {
|
||||
return &iFaceMock{
|
||||
NameFunc: func() string { return "wt-refcount" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newIptRefcountManager(t *testing.T, dual bool) *Manager {
|
||||
t.Helper()
|
||||
var ifMock *iFaceMock
|
||||
if dual {
|
||||
ifMock = iptRefcountIfaceDual()
|
||||
} else {
|
||||
ifMock = iptRefcountIfaceV4()
|
||||
}
|
||||
m, err := Create(ifMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
require.NoError(t, m.Init(nil), "init manager")
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, m.Close(nil), "close manager")
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func iptDnatV4(port uint16) fw.ForwardRule {
|
||||
return fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
||||
TranslatedAddress: netip.MustParseAddr("10.20.0.2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
}
|
||||
}
|
||||
|
||||
func iptDnatV6(port uint16) fw.ForwardRule {
|
||||
return fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
}
|
||||
}
|
||||
|
||||
// TestIptablesDNAT_RefcountBalancedV4 covers a Balanced Add/Delete pair on v4.
|
||||
func TestIptablesDNAT_RefcountBalancedV4(t *testing.T) {
|
||||
m := newIptRefcountManager(t, false)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
r1, err := m.AddDNATRule(iptDnatV4(7081))
|
||||
require.NoError(t, err, "add v4 dnat 1")
|
||||
v4, v6 := state.Counts()
|
||||
require.Equal(t, 1, v4, "v4 refcount after first add")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
|
||||
r2, err := m.AddDNATRule(iptDnatV4(7082))
|
||||
require.NoError(t, err, "add v4 dnat 2")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 2, v4, "v4 refcount after second add")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1))
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 1, v4, "v4 refcount after first delete")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r2))
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount after second delete")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
}
|
||||
|
||||
// TestIptablesDNAT_RefcountBalancedV6 checks the v6 path increments v6 only and
|
||||
// decrements back to zero.
|
||||
func TestIptablesDNAT_RefcountBalancedV6(t *testing.T) {
|
||||
m := newIptRefcountManager(t, true)
|
||||
require.NotNil(t, m.router6, "v6 router")
|
||||
require.Same(t, m.router.ipFwdState, m.router6.ipFwdState, "shared state")
|
||||
state := m.router.ipFwdState
|
||||
|
||||
r1, err := m.AddDNATRule(iptDnatV6(9081))
|
||||
require.NoError(t, err, "add v6 dnat 1")
|
||||
v4, v6 := state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 1, v6, "v6 refcount after first add")
|
||||
|
||||
r2, err := m.AddDNATRule(iptDnatV6(9082))
|
||||
require.NoError(t, err, "add v6 dnat 2")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
||||
require.Equal(t, 2, v6, "v6 refcount after second add")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1))
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
||||
require.Equal(t, 1, v6, "v6 refcount after first delete")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r2))
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 0, v6, "v6 refcount after second delete")
|
||||
}
|
||||
|
||||
// TestIptablesDNAT_DuplicateAddNoLeak verifies the duplicate-rule path returns
|
||||
// without bumping the refcount.
|
||||
func TestIptablesDNAT_DuplicateAddNoLeak(t *testing.T) {
|
||||
m := newIptRefcountManager(t, true)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
rule := iptDnatV4(7083)
|
||||
r1, err := m.AddDNATRule(rule)
|
||||
require.NoError(t, err)
|
||||
v4, _ := state.Counts()
|
||||
require.Equal(t, 1, v4)
|
||||
|
||||
_, err = m.AddDNATRule(rule)
|
||||
require.NoError(t, err, "duplicate add")
|
||||
v4, _ = state.Counts()
|
||||
require.Equal(t, 1, v4, "duplicate add must not increment")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1))
|
||||
v4, _ = state.Counts()
|
||||
require.Equal(t, 0, v4, "single delete must drop to zero")
|
||||
}
|
||||
|
||||
// TestIptablesDNAT_DeleteMissingNoUnderflow verifies Delete on an unknown rule
|
||||
// neither errors nor releases the refcount.
|
||||
func TestIptablesDNAT_DeleteMissingNoUnderflow(t *testing.T) {
|
||||
m := newIptRefcountManager(t, true)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
phantom := iptDnatV4(7099)
|
||||
require.NoError(t, m.DeleteDNATRule(&phantom), "delete missing v4")
|
||||
v4, v6 := state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 0, v6)
|
||||
|
||||
phantom6 := iptDnatV6(9099)
|
||||
require.NoError(t, m.DeleteDNATRule(&phantom6), "delete missing v6")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 0, v6)
|
||||
|
||||
r1, err := m.AddDNATRule(iptDnatV4(7100))
|
||||
require.NoError(t, err)
|
||||
v4, _ = state.Counts()
|
||||
require.Equal(t, 1, v4, "real add still increments after phantom delete")
|
||||
require.NoError(t, m.DeleteDNATRule(r1))
|
||||
}
|
||||
|
||||
// TestIptablesDNAT_DoubleDeleteNoUnderflow verifies a second Delete on the same
|
||||
// rule is a no-op.
|
||||
func TestIptablesDNAT_DoubleDeleteNoUnderflow(t *testing.T) {
|
||||
m := newIptRefcountManager(t, true)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
r1, err := m.AddDNATRule(iptDnatV6(9083))
|
||||
require.NoError(t, err)
|
||||
_, v6 := state.Counts()
|
||||
require.Equal(t, 1, v6)
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "first delete")
|
||||
_, v6 = state.Counts()
|
||||
require.Equal(t, 0, v6)
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "second delete must be no-op")
|
||||
_, v6 = state.Counts()
|
||||
require.Equal(t, 0, v6, "double delete must not underflow")
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
// Forwarding refcounter is per-family but shared between v4 and v6 routers.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
@@ -402,17 +402,33 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
if err := m.router.ipFwdState.RequestForwarding(false); err != nil {
|
||||
return fmt.Errorf("enable IPv4 forwarding: %w", err)
|
||||
}
|
||||
// v6 only when the overlay actually has v6.
|
||||
if m.router6 == nil {
|
||||
return nil
|
||||
}
|
||||
if err := m.router.ipFwdState.RequestForwarding(true); err != nil {
|
||||
if rerr := m.router.ipFwdState.ReleaseForwarding(false); rerr != nil {
|
||||
log.Warnf("rollback v4 forwarding: %v", rerr)
|
||||
}
|
||||
return fmt.Errorf("enable IPv6 forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
var merr *multierror.Error
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(false); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv4 forwarding: %w", err))
|
||||
}
|
||||
return nil
|
||||
if m.router6 != nil {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(true); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv6 forwarding: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
|
||||
@@ -101,7 +101,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(wgIface.Name()),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -763,10 +763,6 @@ func (r *router) updateState() {
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
@@ -841,6 +837,16 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
|
||||
if err := r.ipFwdState.RequestForwarding(r.v6); err != nil {
|
||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||
log.Errorf("rollback failed: %v", rollbackErr)
|
||||
}
|
||||
for key := range rules {
|
||||
delete(r.rules, key)
|
||||
}
|
||||
return nil, fmt.Errorf("enable forwarding: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
@@ -861,12 +867,15 @@ func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
||||
}
|
||||
|
||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
_, hadDNAT := r.rules[ruleKey+dnatSuffix]
|
||||
_, hadSNAT := r.rules[ruleKey+snatSuffix]
|
||||
_, hadFWD := r.rules[ruleKey+fwdSuffix]
|
||||
if !hadDNAT && !hadSNAT && !hadFWD {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
@@ -889,6 +898,10 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
}
|
||||
|
||||
if err := r.ipFwdState.ReleaseForwarding(r.v6); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
208
client/firewall/nftables/dnat_refcount_linux_test.go
Normal file
208
client/firewall/nftables/dnat_refcount_linux_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func nftRefcountIfaceV4() *iFaceMock {
|
||||
return &iFaceMock{
|
||||
NameFunc: func() string { return "wt-refcount" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func nftRefcountIfaceDual() *iFaceMock {
|
||||
return &iFaceMock{
|
||||
NameFunc: func() string { return "wt-refcount" },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
IPv6: netip.MustParseAddr("fd00::1"),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newNftRefcountManager(t *testing.T, dual bool) *Manager {
|
||||
t.Helper()
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
var ifMock *iFaceMock
|
||||
if dual {
|
||||
ifMock = nftRefcountIfaceDual()
|
||||
} else {
|
||||
ifMock = nftRefcountIfaceV4()
|
||||
}
|
||||
m, err := Create(ifMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
require.NoError(t, m.Init(nil), "init manager")
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, m.Close(nil), "close manager")
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func dnatV4(port uint16) fw.ForwardRule {
|
||||
return fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
}
|
||||
}
|
||||
|
||||
func dnatV6(port uint16) fw.ForwardRule {
|
||||
return fw.ForwardRule{
|
||||
Protocol: fw.ProtocolTCP,
|
||||
DestinationPort: fw.Port{Values: []uint16{port}},
|
||||
TranslatedAddress: netip.MustParseAddr("fd00::2"),
|
||||
TranslatedPort: fw.Port{Values: []uint16{80}},
|
||||
}
|
||||
}
|
||||
|
||||
// TestNftablesDNAT_RefcountBalancedV4 verifies that Add/Delete pairs leave the
|
||||
// v4 refcount at zero.
|
||||
func TestNftablesDNAT_RefcountBalancedV4(t *testing.T) {
|
||||
m := newNftRefcountManager(t, false)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
r1, err := m.AddDNATRule(dnatV4(8081))
|
||||
require.NoError(t, err, "add v4 dnat 1")
|
||||
v4, v6 := state.Counts()
|
||||
require.Equal(t, 1, v4, "v4 refcount after first add")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
|
||||
r2, err := m.AddDNATRule(dnatV4(8082))
|
||||
require.NoError(t, err, "add v4 dnat 2")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 2, v4, "v4 refcount after second add")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "delete v4 dnat 1")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 1, v4, "v4 refcount after first delete")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r2), "delete v4 dnat 2")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount after second delete")
|
||||
require.Equal(t, 0, v6, "v6 refcount unchanged")
|
||||
}
|
||||
|
||||
// TestNftablesDNAT_RefcountBalancedV6 verifies the v6 path increments v6 only
|
||||
// and decrements back to zero on Delete.
|
||||
func TestNftablesDNAT_RefcountBalancedV6(t *testing.T) {
|
||||
m := newNftRefcountManager(t, true)
|
||||
require.NotNil(t, m.router6, "v6 router")
|
||||
require.Same(t, m.router.ipFwdState, m.router6.ipFwdState, "shared state")
|
||||
state := m.router.ipFwdState
|
||||
|
||||
r1, err := m.AddDNATRule(dnatV6(9091))
|
||||
require.NoError(t, err, "add v6 dnat 1")
|
||||
v4, v6 := state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
||||
require.Equal(t, 1, v6, "v6 refcount after first add")
|
||||
|
||||
r2, err := m.AddDNATRule(dnatV6(9092))
|
||||
require.NoError(t, err, "add v6 dnat 2")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 2, v6, "v6 refcount after second add")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "delete v6 dnat 1")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount unchanged")
|
||||
require.Equal(t, 1, v6, "v6 refcount after first delete")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r2), "delete v6 dnat 2")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 0, v6, "v6 refcount after second delete")
|
||||
}
|
||||
|
||||
// TestNftablesDNAT_DuplicateAddNoLeak verifies that a duplicate Add (same
|
||||
// ForwardRule) does not double-increment the refcount.
|
||||
func TestNftablesDNAT_DuplicateAddNoLeak(t *testing.T) {
|
||||
m := newNftRefcountManager(t, true)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
rule := dnatV4(8083)
|
||||
r1, err := m.AddDNATRule(rule)
|
||||
require.NoError(t, err, "add v4 dnat")
|
||||
v4, _ := state.Counts()
|
||||
require.Equal(t, 1, v4)
|
||||
|
||||
// duplicate add: same rule ID, must be a no-op for the refcount.
|
||||
_, err = m.AddDNATRule(rule)
|
||||
require.NoError(t, err, "duplicate add")
|
||||
v4, _ = state.Counts()
|
||||
require.Equal(t, 1, v4, "duplicate add must not increment")
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "delete v4 dnat")
|
||||
v4, _ = state.Counts()
|
||||
require.Equal(t, 0, v4, "single delete must drop to zero")
|
||||
}
|
||||
|
||||
// TestNftablesDNAT_DeleteMissingNoUnderflow verifies deleting a rule that was
|
||||
// never added does not underflow the refcount.
|
||||
func TestNftablesDNAT_DeleteMissingNoUnderflow(t *testing.T) {
|
||||
m := newNftRefcountManager(t, true)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
// Construct a Rule reference for something never added. The router stores
|
||||
// rules by ID(), and DeleteDNATRule looks them up in r.rules; a missing
|
||||
// entry must be a no-op rather than calling Release.
|
||||
phantom := dnatV4(8099)
|
||||
require.NoError(t, m.DeleteDNATRule(&phantom), "delete missing v4 dnat")
|
||||
v4, v6 := state.Counts()
|
||||
require.Equal(t, 0, v4, "v4 refcount unaffected by missing delete")
|
||||
require.Equal(t, 0, v6, "v6 refcount unaffected")
|
||||
|
||||
phantom6 := dnatV6(9099)
|
||||
require.NoError(t, m.DeleteDNATRule(&phantom6), "delete missing v6 dnat")
|
||||
v4, v6 = state.Counts()
|
||||
require.Equal(t, 0, v4)
|
||||
require.Equal(t, 0, v6, "v6 refcount unaffected by missing delete")
|
||||
|
||||
// And after a phantom delete, a real add still results in count=1.
|
||||
r1, err := m.AddDNATRule(dnatV4(8100))
|
||||
require.NoError(t, err, "add v4 dnat after phantom delete")
|
||||
v4, _ = state.Counts()
|
||||
require.Equal(t, 1, v4, "real add still increments after phantom delete")
|
||||
require.NoError(t, m.DeleteDNATRule(r1))
|
||||
}
|
||||
|
||||
// TestNftablesDNAT_DoubleDeleteNoUnderflow verifies that deleting the same rule
|
||||
// twice does not underflow the refcount (the second delete is a no-op).
|
||||
func TestNftablesDNAT_DoubleDeleteNoUnderflow(t *testing.T) {
|
||||
m := newNftRefcountManager(t, true)
|
||||
state := m.router.ipFwdState
|
||||
|
||||
r1, err := m.AddDNATRule(dnatV6(9093))
|
||||
require.NoError(t, err)
|
||||
_, v6 := state.Counts()
|
||||
require.Equal(t, 1, v6)
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "first delete")
|
||||
_, v6 = state.Counts()
|
||||
require.Equal(t, 0, v6)
|
||||
|
||||
require.NoError(t, m.DeleteDNATRule(r1), "second delete must be no-op")
|
||||
_, v6 = state.Counts()
|
||||
require.Equal(t, 0, v6, "double delete must not underflow")
|
||||
}
|
||||
@@ -105,8 +105,8 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
// Share the per-family forwarding refcounter with the v4 router so a v4
|
||||
// rule and a v6 rule against the same state machine cooperate cleanly.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
@@ -530,17 +530,33 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
if err := m.router.ipFwdState.RequestForwarding(false); err != nil {
|
||||
return fmt.Errorf("enable IPv4 forwarding: %w", err)
|
||||
}
|
||||
// v6 only when the overlay actually has v6.
|
||||
if m.router6 == nil {
|
||||
return nil
|
||||
}
|
||||
if err := m.router.ipFwdState.RequestForwarding(true); err != nil {
|
||||
if rerr := m.router.ipFwdState.ReleaseForwarding(false); rerr != nil {
|
||||
log.Warnf("rollback v4 forwarding: %v", rerr)
|
||||
}
|
||||
return fmt.Errorf("enable IPv6 forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
var merr *multierror.Error
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(false); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv4 forwarding: %w", err))
|
||||
}
|
||||
return nil
|
||||
if m.router6 != nil {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(true); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("disable IPv6 forwarding: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
|
||||
@@ -93,7 +93,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(wgIface.Name()),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
@@ -1550,10 +1550,6 @@ func (r *router) refreshRulesMap() error {
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
@@ -1564,7 +1560,18 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
// Request forwarding before queueing rules: addDnatRedirect/addDnatMasq
|
||||
// buffer netlink messages on r.conn that the next caller's Flush would
|
||||
// commit if we returned without flushing them ourselves.
|
||||
v6 := r.af.tableFamily == nftables.TableFamilyIPv6
|
||||
if err := r.ipFwdState.RequestForwarding(v6); err != nil {
|
||||
return nil, fmt.Errorf("enable forwarding: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||
if rerr := r.ipFwdState.ReleaseForwarding(v6); rerr != nil {
|
||||
log.Warnf("rollback forwarding refcount: %v", rerr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1576,6 +1583,11 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
// TODO: find chains with drop policies and add rules there
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
if rerr := r.ipFwdState.ReleaseForwarding(v6); rerr != nil {
|
||||
log.Warnf("rollback forwarding refcount: %v", rerr)
|
||||
}
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
return nil, fmt.Errorf("flush rules: %w", err)
|
||||
}
|
||||
|
||||
@@ -1778,16 +1790,18 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
|
||||
}
|
||||
|
||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
_, hadDNAT := r.rules[ruleKey+dnatSuffix]
|
||||
_, hadSNAT := r.rules[ruleKey+snatSuffix]
|
||||
if !hadDNAT && !hadSNAT {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
|
||||
@@ -1824,6 +1838,10 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
}
|
||||
|
||||
if err := r.ipFwdState.ReleaseForwarding(r.af.tableFamily == nftables.TableFamilyIPv6); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
|
||||
@@ -844,6 +844,10 @@ func collectSysctls() string {
|
||||
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
|
||||
listInterfaceSysctls("ipv4", "src_valid_mark")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "accept_ra", append(
|
||||
[]string{"net.ipv6.conf.all.accept_ra", "net.ipv6.conf.default.accept_ra"},
|
||||
listInterfaceSysctls("ipv6", "accept_ra")...,
|
||||
))
|
||||
writeSysctlGroup(&builder, "conntrack", []string{
|
||||
"net.netfilter.nf_conntrack_acct",
|
||||
"net.netfilter.nf_conntrack_tcp_loose",
|
||||
|
||||
@@ -2,54 +2,109 @@ package ipfwdstate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// IPForwardingState is a struct that keeps track of the IP forwarding state.
|
||||
// todo: read initial state of the IP forwarding from the system and reset the state based on it.
|
||||
// todo: separate v4/v6 forwarding state, since the sysctls are independent
|
||||
// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables
|
||||
// manager shares one instance between both routers, which works only because
|
||||
// EnableIPForwarding enables both sysctls in a single call.
|
||||
// IPForwardingState tracks v4 and v6 IP-forwarding sysctl enables with
|
||||
// independent refcounts so a v4-only routing setup doesn't flip v6 sysctls.
|
||||
type IPForwardingState struct {
|
||||
enabledCounter int
|
||||
mu sync.Mutex
|
||||
|
||||
v4Count int
|
||||
v6Count int
|
||||
|
||||
wgIfaceName string
|
||||
v6Saved map[string]int
|
||||
}
|
||||
|
||||
func NewIPForwardingState() *IPForwardingState {
|
||||
return &IPForwardingState{}
|
||||
func NewIPForwardingState(wgIfaceName string) *IPForwardingState {
|
||||
return &IPForwardingState{wgIfaceName: wgIfaceName}
|
||||
}
|
||||
|
||||
func (f *IPForwardingState) RequestForwarding() error {
|
||||
if f.enabledCounter != 0 {
|
||||
f.enabledCounter++
|
||||
return nil
|
||||
}
|
||||
// Counts returns the current v4 and v6 refcounts. Intended for diagnostics
|
||||
// and tests.
|
||||
func (f *IPForwardingState) Counts() (v4, v6 int) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.v4Count, f.v6Count
|
||||
}
|
||||
|
||||
if err := systemops.EnableIPForwarding(); err != nil {
|
||||
return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
|
||||
}
|
||||
f.enabledCounter = 1
|
||||
log.Info("IP forwarding enabled")
|
||||
// RequestForwarding enables the family's forwarding sysctl on first request.
|
||||
func (f *IPForwardingState) RequestForwarding(v6 bool) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if v6 {
|
||||
return f.requestV6()
|
||||
}
|
||||
return f.requestV4()
|
||||
}
|
||||
|
||||
// ReleaseForwarding decrements the family counter. The last v6 release restores
|
||||
// what enable captured. v4 stays on: net.ipv4.ip_forward is co-owned by other
|
||||
// tooling (docker, k8s, libvirt).
|
||||
func (f *IPForwardingState) ReleaseForwarding(v6 bool) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if v6 {
|
||||
return f.releaseV6()
|
||||
}
|
||||
f.releaseV4()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *IPForwardingState) ReleaseForwarding() error {
|
||||
if f.enabledCounter == 0 {
|
||||
return nil
|
||||
func (f *IPForwardingState) requestV4() error {
|
||||
if f.v4Count == 0 {
|
||||
if err := systemops.EnableV4IPForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IPv4 forwarding: %w", err)
|
||||
}
|
||||
log.Info("IPv4 forwarding enabled")
|
||||
}
|
||||
|
||||
if f.enabledCounter > 1 {
|
||||
f.enabledCounter--
|
||||
return nil
|
||||
}
|
||||
|
||||
// if failed to disable IP forwarding we anyway decrement the counter
|
||||
f.enabledCounter = 0
|
||||
|
||||
// todo call systemops.DisableIPForwarding()
|
||||
f.v4Count++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *IPForwardingState) releaseV4() {
|
||||
if f.v4Count > 0 {
|
||||
f.v4Count--
|
||||
}
|
||||
}
|
||||
|
||||
func (f *IPForwardingState) requestV6() error {
|
||||
if f.v6Count == 0 {
|
||||
saved, err := systemops.EnableV6IPForwarding(f.wgIfaceName)
|
||||
if err != nil {
|
||||
if rerr := systemops.DisableV6IPForwarding(saved); rerr != nil {
|
||||
log.Warnf("rollback partial v6 sysctls: %v", rerr)
|
||||
}
|
||||
return fmt.Errorf("enable IPv6 forwarding: %w", err)
|
||||
}
|
||||
f.v6Saved = saved
|
||||
log.Info("IPv6 forwarding enabled")
|
||||
}
|
||||
f.v6Count++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *IPForwardingState) releaseV6() error {
|
||||
if f.v6Count == 0 {
|
||||
return nil
|
||||
}
|
||||
f.v6Count--
|
||||
if f.v6Count > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
saved := f.v6Saved
|
||||
f.v6Saved = nil
|
||||
if err := systemops.DisableV6IPForwarding(saved); err != nil {
|
||||
return fmt.Errorf("disable IPv6 forwarding: %w", err)
|
||||
}
|
||||
log.Info("IPv6 forwarding disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -32,8 +32,17 @@ func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
func EnableV4IPForwarding() error {
|
||||
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableV6IPForwarding(string) (map[string]int, error) {
|
||||
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
|
||||
return map[string]int{}, nil
|
||||
}
|
||||
|
||||
func DisableV6IPForwarding(map[string]int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -58,8 +58,17 @@ func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
func EnableV4IPForwarding() error {
|
||||
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableV6IPForwarding(string) (map[string]int, error) {
|
||||
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
|
||||
return map[string]int{}, nil
|
||||
}
|
||||
|
||||
func DisableV6IPForwarding(map[string]int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -763,13 +763,10 @@ func flushRoutes(tableID, family int) error {
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
func EnableV4IPForwarding() error {
|
||||
if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil {
|
||||
log.Warnf("failed to enable IPv6 forwarding: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -43,8 +43,17 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
func EnableV4IPForwarding() error {
|
||||
log.Infof("Enable IPv4 forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableV6IPForwarding(string) (map[string]int, error) {
|
||||
log.Infof("Enable IPv6 forwarding is not implemented on %s", runtime.GOOS)
|
||||
return map[string]int{}, nil
|
||||
}
|
||||
|
||||
func DisableV6IPForwarding(map[string]int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
82
client/internal/routemanager/systemops/v6forwarding_linux.go
Normal file
82
client/internal/routemanager/systemops/v6forwarding_linux.go
Normal file
@@ -0,0 +1,82 @@
|
||||
//go:build !android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
)
|
||||
|
||||
const (
|
||||
// 1 (default) accepts RAs only while forwarding is off; 2 keeps RA
|
||||
// acceptance on regardless, so RA-installed host defaults survive our
|
||||
// v6 forwarding flip.
|
||||
acceptRAInterfacePath = "net.ipv6.conf.%s.accept_ra"
|
||||
acceptRAProcPathFormat = "/proc/sys/net/ipv6/conf/%s/accept_ra"
|
||||
)
|
||||
|
||||
// EnableV6IPForwarding bumps accept_ra=2 on host v6 interfaces before flipping
|
||||
// forwarding=1, so RA-installed host defaults survive. Returns the prior values
|
||||
// of sysctls we actually changed; entries already at the target are omitted.
|
||||
func EnableV6IPForwarding(wgIfaceName string) (map[string]int, error) {
|
||||
saved := map[string]int{}
|
||||
bumpAcceptRA(saved, wgIfaceName)
|
||||
|
||||
oldVal, err := sysctl.Set(ipv6ForwardingPath, 1, false)
|
||||
if err != nil {
|
||||
return saved, err
|
||||
}
|
||||
if oldVal != 1 {
|
||||
saved[ipv6ForwardingPath] = oldVal
|
||||
}
|
||||
return saved, nil
|
||||
}
|
||||
|
||||
// DisableV6IPForwarding restores what EnableV6IPForwarding captured.
|
||||
func DisableV6IPForwarding(saved map[string]int) error {
|
||||
var result *multierror.Error
|
||||
for key, value := range saved {
|
||||
if _, err := sysctl.Set(key, value, false); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("restore %s: %w", key, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func bumpAcceptRA(saved map[string]int, wgIfaceName string) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Warnf("list interfaces for accept_ra: %v", err)
|
||||
return
|
||||
}
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || intf.Name == wgIfaceName {
|
||||
continue
|
||||
}
|
||||
bumpAcceptRAForInterface(saved, intf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func bumpAcceptRAForInterface(saved map[string]int, name string) {
|
||||
key := fmt.Sprintf(acceptRAInterfacePath, name)
|
||||
// Build procfs path from name, not the dotted key: VLAN names like eth0.100.
|
||||
if _, err := os.Stat(fmt.Sprintf(acceptRAProcPathFormat, name)); err != nil {
|
||||
return
|
||||
}
|
||||
// onlyIfOne=true: leave admin overrides (0, 2) alone.
|
||||
oldVal, err := sysctl.Set(key, 2, true)
|
||||
if err != nil {
|
||||
log.Warnf("bump %s: %v", key, err)
|
||||
return
|
||||
}
|
||||
if oldVal != 2 {
|
||||
saved[key] = oldVal
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user