This commit is contained in:
İsmail
2024-11-07 19:53:54 +03:00
27 changed files with 613 additions and 310 deletions

View File

@@ -13,9 +13,10 @@ concurrency:
jobs: jobs:
test: test:
strategy: strategy:
fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'mysql' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -49,7 +50,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04

View File

@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
} }
if slices.Contains(wellKnown, addr.String()) { if slices.Contains(wellKnown, addr.String()) {

View File

@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{ m.optionalEntries["FORWARD"] = []entry{
{ {
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2, position: 2,
}, },
} }
m.optionalEntries["PREROUTING"] = []entry{ m.optionalEntries["PREROUTING"] = []entry{
{ {
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1, position: 1,
}, },
} }

View File

@@ -18,22 +18,24 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) nbnet "github.com/netbirdio/netbird/util/net"
const (
ipv4Nat = "netbird-rt-nat"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set" matchSet = "--match-set"
) )
@@ -323,24 +325,25 @@ func (r *router) Reset() error {
} }
func (r *router) cleanUpDefaultForwardRules() error { func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules() if err := r.cleanJumpRules(); err != nil {
if err != nil { return fmt.Errorf("clean jump rules: %w", err)
return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chainInfo := range []struct {
table := r.getTableForChain(chain) chain string
table string
ok, err := r.iptablesClient.ChainExists(table, chain) }{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err) return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
return err
} else if ok { } else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain) if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
if err != nil { return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
} }
} }
} }
@@ -349,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
} }
func (r *router) createContainers() error { func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chainInfo := range []struct {
if err := r.createAndSetupChain(chain); err != nil { chain string
return fmt.Errorf("create chain %s: %w", chain, err) table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
} }
@@ -359,6 +369,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("insert established rule: %w", err) return fmt.Errorf("insert established rule: %w", err)
} }
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil { if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err) return fmt.Errorf("add jump rules: %w", err)
} }
@@ -366,6 +380,32 @@ func (r *router) createContainers() error {
return nil return nil
} }
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error { func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain) table := r.getTableForChain(chain)
@@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
} }
func (r *router) getTableForChain(chain string) string { func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT { switch chain {
case chainRTNAT:
return tableNat return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
} }
return tableFilter
} }
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
@@ -398,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
} }
func (r *router) addJumpRules() error { func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT} // Jump to NAT chain
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) natRule := []string{"-j", chainRTNAT}
if err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return err return fmt.Errorf("add nat jump rule: %v", err)
} }
r.rules[ipv4Nat] = rule r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil return nil
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat] for _, ruleKey := range []string{jumpNat, jumpPre} {
if found { if rule, exists := r.rules[ruleKey]; exists {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) table := tableNat
if err != nil { chain := chainPOSTROUTING
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
} }
} }
return nil return nil
} }
@@ -424,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} }
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) markValue := nbnet.PreroutingFwmarkMasquerade
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { if pair.Inverse {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
r.rules[ruleKey] = rule r.rules[ruleKey] = rule
return nil return nil
} }
@@ -444,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else { } else {
log.Debugf("nat rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
return nil return nil
@@ -482,16 +555,6 @@ func (r *router) updateState() {
} }
} }
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string var rule []string

View File

@@ -3,17 +3,18 @@
package iptables package iptables
import ( import (
"fmt"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {
@@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.Len(t, manager.rules, 2, "should have created rules map") // Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true, Masquerade: true,
} }
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.AddNatRule(pair)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "adding NAT rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_AddNatRule(t *testing.T) { func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
err := manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}() }()
err = manager.AddNatRule(testCase.InputPair) err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) markingRule := []string{
"-i", ifaceMock.Name(),
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) "-m", "conntrack",
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) "--ctstate", "NEW",
if testCase.InputPair.Masquerade { "-s", testCase.InputPair.Source.String(),
require.True(t, exists, "nat rule should be created") "-d", testCase.InputPair.Destination.String(),
foundNatRule, foundNat := manager.rules[natRuleKey] "-j", "MARK", "--set-mark",
require.True(t, foundNat, "nat rule should exist in the map") fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
} }
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created") require.True(t, exists, "marking rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey] foundRule, found := manager.rules[natRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map") require.True(t, found, "marking rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") require.Equal(t, markingRule, foundRule, "stored marking rule should match")
} else { } else {
require.False(t, exists, "nat rule should not be created") require.False(t, exists, "marking rule should not be created")
_, foundNat := manager.rules[inNatRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map") require.False(t, found, "marking rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
} }
}) })
} }
} }
func TestIptablesManager_RemoveNatRule(t *testing.T) { func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.NoError(t, err, "shouldn't return error") err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) markingRule := []string{
require.False(t, exists, "nat rule should not exist") "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map") require.False(t, found, "marking rule should not exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) // Check inverse rule removal
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) inversePair := firewall.GetInversePair(testCase.InputPair)
require.False(t, exists, "income nat rule should not exist") inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
_, found = manager.rules[inNatRuleKey] exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.False(t, found, "income nat rule should exist in the manager map") require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
}) })
} }
} }

View File

@@ -17,6 +17,7 @@ import (
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t" NatFormat = "netbird-nat-%s-%t"
) )

View File

@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
}, },
&expr.Immediate{ &expr.Immediate{
Register: 1, Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyMARK, Key: expr.MetaKeyMARK,
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,

View File

@@ -21,6 +21,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -124,7 +125,6 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1 prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat, Name: chainNameRoutingNat,
Table: r.workTable, Table: r.workTable,
@@ -133,6 +133,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil { if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
@@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME op := expr.CmpOpEq
notDir := expr.MetaKeyOIFNAME
if pair.Inverse { if pair.Inverse {
dir = expr.MetaKeyOIFNAME op = expr.CmpOpNeq
notDir = expr.MetaKeyIIFNAME
} }
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Meta{ // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
Key: dir, // Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1, Register: 1,
}, },
&expr.Cmp{ &expr.Bitwise{
Op: expr.CmpOpEq, SourceRegister: 1,
Register: 1, DestRegister: 1,
Data: intf, Len: 4,
}, Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
}, },
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpNeq, Op: expr.CmpOpNeq,
Register: 1, Register: 1,
Data: lo, Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
}, },
} }
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs, exprs = append(exprs,
&expr.Counter{}, &expr.Masq{}, &expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
) )
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists { if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingNat], Chain: r.chains[chainNamePrerouting],
Exprs: exprs, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil return nil
} }
@@ -723,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// RemoveNatRule removes a nftables rule pair from nat chains // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -749,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil return nil
} }
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error { func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule) err := r.conn.DelRule(rule)
if err != nil { if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else { } else {
log.Debugf("nftables: nat rule %s not found", ruleKey) log.Debugf("nftables: prerouting rule %s not found", ruleKey)
} }
return nil return nil

View File

@@ -10,6 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -32,100 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock) // need fw manager to init both acl mgr and router for all chains to be present
require.NoError(t, err, "failed to create router") manager, err := Create(ifaceMock)
require.NoError(t, manager.init(table)) t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { rtr := manager.router
require.NoError(t, manager.Reset(), "failed to reset rules") err = rtr.AddNatRule(testCase.InputPair)
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted") require.NoError(t, err, "pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) { t.Cleanup(func() {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
}(manager, testCase.InputPair) })
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) // Build expected expressions for connection tracking
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) conntrackExprs := []expr.Any{
testingExpression := append(sourceExp, destExp...) //nolint:gocritic &expr.Ct{
testingExpression = append(testingExpression, Key: expr.CtKeySTATE,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: ifname(ifaceMock.Name()), Data: ifname(ifaceMock.Name()),
}, },
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
} }
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
if testCase.InputPair.Masquerade { // Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) // Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range rtr.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) if chain.Name == chainNamePrerouting {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
for _, rule := range rules { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey { for _, rule := range rules {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match") if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = 1 // Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
} }
} }
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
} }
}) })
} }
} }
@@ -135,68 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create router") t.Cleanup(func() {
require.NoError(t, manager.init(table)) require.NoError(t, manager.Reset(nil))
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
}) })
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) rtr := manager.router
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic // First add the NAT rule using the router's method
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ // Verify the rule was added
Table: manager.workTable, natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
Chain: manager.chains[chainNameRoutingNat], found := false
Exprs: natExp, rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
UserData: []byte(inNatRuleKey), require.NoError(t, err, "should list rules")
}) for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
err = nftablesTestingClient.Flush() found = true
require.NoError(t, err, "shouldn't return error") break
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
} }
} }
require.True(t, found, "NAT rule should exist before removal")
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
}) })
} }
} }

View File

@@ -626,6 +626,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil { if s.actCancel == nil {
return nil, fmt.Errorf("service is not up") return nil, fmt.Errorf("service is not up")
} }

2
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -524,8 +524,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=

View File

@@ -308,7 +308,7 @@ type UserInfo struct {
IsServiceUser bool `json:"is_service_user"` IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"` IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"` NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login" gorm:"default:null"` LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"` Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"` IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"` Permissions UserPermissions `json:"permissions"`
@@ -1249,7 +1249,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id) log.Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextInactivePeerExpiration() return account.GetNextInactivePeerExpiration()
} }

View File

@@ -29,14 +29,18 @@ import (
) )
type MocIntegratedValidator struct { type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, nil if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{}) validatedPeers := make(map[string]struct{})

View File

@@ -11,7 +11,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies // IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface { type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)

View File

@@ -453,8 +453,8 @@ func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtr
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, nil return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {

View File

@@ -189,7 +189,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
} }
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) var requiresPeerUpdates bool
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -265,7 +266,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err return nil, err
} }
if peerLabelUpdated { if peerLabelUpdated || requiresPeerUpdates {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} }

View File

@@ -40,7 +40,7 @@ type Peer struct {
InactivityExpirationEnabled bool InactivityExpirationEnabled bool
// LastLogin the time when peer performed last login operation // LastLogin the time when peer performed last login operation
LastLogin time.Time `gorm:"default:null"` LastLogin time.Time
// CreatedAt records the time the peer was created // CreatedAt records the time the peer was created
CreatedAt time.Time CreatedAt time.Time
// Indicate ephemeral peer attribute // Indicate ephemeral peer attribute
@@ -51,7 +51,7 @@ type Peer struct {
type PeerStatus struct { //nolint:revive type PeerStatus struct { //nolint:revive
// LastSeen is the last time peer was connected to the management service // LastSeen is the last time peer was connected to the management service
LastSeen time.Time `gorm:"default:null"` LastSeen time.Time
// Connected indicates whether peer is connected to the management service or not // Connected indicates whether peer is connected to the management service or not
Connected bool Connected bool
// LoginExpired // LoginExpired

View File

@@ -22,6 +22,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -1398,6 +1399,50 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
} }
}) })
t.Run("validator requires update", func(t *testing.T) {
requireUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, true, nil
}
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
t.Run("validator requires no update", func(t *testing.T) {
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Adding peer to group linked with policy should update account peers and send peer update // Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) { t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{

View File

@@ -35,7 +35,7 @@ type PersonalAccessToken struct {
// scope could be added in future // scope could be added in future
CreatedBy string CreatedBy string
CreatedAt time.Time CreatedAt time.Time
LastUsed time.Time `gorm:"default:null"` LastUsed time.Time
} }
func (t *PersonalAccessToken) Copy() *PersonalAccessToken { func (t *PersonalAccessToken) Copy() *PersonalAccessToken {

View File

@@ -87,7 +87,7 @@ type SetupKey struct {
// UsedTimes indicates how many times the key was used // UsedTimes indicates how many times the key was used
UsedTimes int UsedTimes int
// LastUsed last time the key was used for peer registration // LastUsed last time the key was used for peer registration
LastUsed time.Time `gorm:"default:null"` LastUsed time.Time
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
AutoGroups []string `gorm:"serializer:json"` AutoGroups []string `gorm:"serializer:json"`
// UsageLimit indicates the number of times this key can be used to enroll a machine. // UsageLimit indicates the number of times this key can be used to enroll a machine.

View File

@@ -70,9 +70,21 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
if err != nil { if err != nil {
conns = runtime.NumCPU() conns = runtime.NumCPU()
} }
if storeEngine == SqliteStoreEngine {
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
}
conns = 1
}
sql.SetMaxOpenConns(conns) sql.SetMaxOpenConns(conns)
log.Infof("Set max open db connections to %d", conns) log.WithContext(ctx).Infof("Set max open db connections to %d", conns)
if storeEngine == MysqlStoreEngine {
sql.SetConnMaxLifetime(120)
}
if err := migrate(ctx, db); err != nil { if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err) return nil, fmt.Errorf("migrate: %w", err)
@@ -1048,7 +1060,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
// NewMysqlStore creates a new MySQL store. // NewMysqlStore creates a new MySQL store.
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(mysql.Open(dsn + "?charset=utf8&parseTime=True"), getGormConfig()) db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True"), getGormConfig())
if err != nil { if err != nil {
return nil, err return nil, err
} }

40
management/server/testdata/mysql.cnf vendored Normal file
View File

@@ -0,0 +1,40 @@
# For advice on how to change settings please see
# http://dev.mysql.com/doc/refman/8.1/en/server-configuration-defaults.html
[mysqld]
#
# Remove leading # and set to the amount of RAM for the most important data
# cache in MySQL. Start at 70% of total RAM for dedicated server, else 10%.
# innodb_buffer_pool_size = 128M
#
# Remove leading # to turn on a very important data integrity option: logging
# changes to the binary log between backups.
# log_bin
#
# Remove leading # to set options mainly useful for reporting servers.
# The server defaults are faster for transactions and fast SELECTs.
# Adjust sizes as needed, experiment to find the optimal values.
# join_buffer_size = 128M
# sort_buffer_size = 2M
# read_rnd_buffer_size = 2M
# Remove leading # to revert to previous value for default_authentication_plugin,
# this will increase compatibility with older clients. For background, see:
# https://dev.mysql.com/doc/refman/8.1/en/server-system-variables.html#sysvar_default_authentication_plugin
# default-authentication-plugin=mysql_native_password
host_cache_size=0
skip-name-resolve
datadir=/var/lib/mysql
socket=/var/run/mysqld/mysqld.sock
secure-file-priv=/var/lib/mysql-files
user=mysql
sql_mode=""
wait_timeout=300
interactive_timeout=300
innodb_flush_log_at_trx_commit=2
pid-file=/var/run/mysqld/mysqld.pid
[client]
socket=/var/run/mysqld/mysqld.sock
!includedir /etc/mysql/conf.d/

View File

@@ -35,9 +35,12 @@ func CreatePGDB() (func(), error) {
func CreateMyDB() (func(), error) { func CreateMyDB() (func(), error) {
mysqlConfigPath := "../../management/server/testdata/mysql.cnf"
ctx := context.Background() ctx := context.Background()
c, err := mysql.Run(ctx, c, err := mysql.Run(ctx,
"mysql:8.0.40", "mysql:8.0.40",
mysql.WithConfigFile(mysqlConfigPath),
mysql.WithDatabase("netbird"), mysql.WithDatabase("netbird"),
mysql.WithUsername("netbird"), mysql.WithUsername("netbird"),
mysql.WithPassword("mysql"), mysql.WithPassword("mysql"),

View File

@@ -74,7 +74,7 @@ type User struct {
// Blocked indicates whether the user is blocked. Blocked users can't use the system. // Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool Blocked bool
// LastLogin is the last time the user logged in to IdP // LastLogin is the last time the user logged in to IdP
LastLogin time.Time `gorm:"default:null"` LastLogin time.Time
// CreatedAt records the time the user was created // CreatedAt records the time the user was created
CreatedAt time.Time CreatedAt time.Time

View File

@@ -63,13 +63,14 @@ func (l *Listener) Shutdown(ctx context.Context) error {
} }
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
connRemoteAddr := remoteAddr(r)
wsConn, err := websocket.Accept(w, r, nil) wsConn, err := websocket.Accept(w, r, nil)
if err != nil { if err != nil {
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err) log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
return return
} }
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil { if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error") err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil { if err != nil {
@@ -90,3 +91,10 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
conn := NewConn(wsConn, lAddr, rAddr) conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn) l.acceptFn(conn)
} }
func remoteAddr(r *http.Request) string {
if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" {
return r.RemoteAddr
}
return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}

View File

@@ -11,8 +11,11 @@ import (
const ( const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard // NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00 NetbirdFwmark = 0x1BD00
PreroutingFwmark = 0x1BD01
PreroutingFwmarkRedirected = 0x1BD01
PreroutingFwmarkMasquerade = 0x1BD11
PreroutingFwmarkMasqueradeReturn = 0x1BD12
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
) )