diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 208e74d53..7ebe0442d 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool { "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 + + "128.0.0.0", "8000::", // 2nd split subnet for default routes } if slices.Contains(wellKnown, addr.String()) { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 5cd69245b..1c0527ebc 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialOptionalEntries() { m.optionalEntries["FORWARD"] = []entry{ { - spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules}, position: 2, }, } m.optionalEntries["PREROUTING"] = []entry{ { - spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)}, position: 1, }, } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 9b75640b4..d067a3e7b 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -18,22 +18,24 @@ import ( "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" -) - -const ( - ipv4Nat = "netbird-rt-nat" + nbnet "github.com/netbirdio/netbird/util/net" ) // constants needed to manage and create iptable rules const ( tableFilter = "filter" tableNat = "nat" + tableMangle = "mangle" chainPOSTROUTING = "POSTROUTING" + chainPREROUTING = "PREROUTING" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWD = "NETBIRD-RT-FWD" + chainRTPRE = "NETBIRD-RT-PRE" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" + jumpPre = "jump-pre" + jumpNat = "jump-nat" matchSet = "--match-set" ) @@ -323,24 +325,25 @@ func (r *router) Reset() error { } func (r *router) cleanUpDefaultForwardRules() error { - err := r.cleanJumpRules() - if err != nil { - return err + if err := r.cleanJumpRules(); err != nil { + return fmt.Errorf("clean jump rules: %w", err) } log.Debug("flushing routing related tables") - for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := r.getTableForChain(chain) - - ok, err := r.iptablesClient.ChainExists(table, chain) + for _, chainInfo := range []struct { + chain string + table string + }{ + {chainRTFWD, tableFilter}, + {chainRTNAT, tableNat}, + {chainRTPRE, tableMangle}, + } { + ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { - log.Errorf("failed check chain %s, error: %v", chain, err) - return err + return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } else if ok { - err = r.iptablesClient.ClearAndDeleteChain(table, chain) - if err != nil { - log.Errorf("failed cleaning chain %s, error: %v", chain, err) - return err + if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } } } @@ -349,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error { } func (r *router) createContainers() error { - for _, chain := range []string{chainRTFWD, chainRTNAT} { - if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %w", chain, err) + for _, chainInfo := range []struct { + chain string + table string + }{ + {chainRTFWD, tableFilter}, + {chainRTPRE, tableMangle}, + {chainRTNAT, tableNat}, + } { + if err := r.createAndSetupChain(chainInfo.chain); err != nil { + return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } } @@ -359,6 +369,10 @@ func (r *router) createContainers() error { return fmt.Errorf("insert established rule: %w", err) } + if err := r.addPostroutingRules(); err != nil { + return fmt.Errorf("add static nat rules: %w", err) + } + if err := r.addJumpRules(); err != nil { return fmt.Errorf("add jump rules: %w", err) } @@ -366,6 +380,32 @@ func (r *router) createContainers() error { return nil } +func (r *router) addPostroutingRules() error { + // First rule for outbound masquerade + rule1 := []string{ + "-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade), + "!", "-o", "lo", + "-j", routingFinalNatJump, + } + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil { + return fmt.Errorf("add outbound masquerade rule: %v", err) + } + r.rules["static-nat-outbound"] = rule1 + + // Second rule for return traffic masquerade + rule2 := []string{ + "-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn), + "-o", r.wgIface.Name(), + "-j", routingFinalNatJump, + } + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil { + return fmt.Errorf("add return masquerade rule: %v", err) + } + r.rules["static-nat-return"] = rule2 + + return nil +} + func (r *router) createAndSetupChain(chain string) error { table := r.getTableForChain(chain) @@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error { } func (r *router) getTableForChain(chain string) string { - if chain == chainRTNAT { + switch chain { + case chainRTNAT: return tableNat + case chainRTPRE: + return tableMangle + default: + return tableFilter } - return tableFilter } func (r *router) insertEstablishedRule(chain string) error { @@ -398,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error { } func (r *router) addJumpRules() error { - rule := []string{"-j", chainRTNAT} - err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) - if err != nil { - return err + // Jump to NAT chain + natRule := []string{"-j", chainRTNAT} + if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { + return fmt.Errorf("add nat jump rule: %v", err) } - r.rules[ipv4Nat] = rule + r.rules[jumpNat] = natRule + + // Jump to prerouting chain + preRule := []string{"-j", chainRTPRE} + if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil { + return fmt.Errorf("add prerouting jump rule: %v", err) + } + r.rules[jumpPre] = preRule return nil } func (r *router) cleanJumpRules() error { - rule, found := r.rules[ipv4Nat] - if found { - err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) - if err != nil { - return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) + for _, ruleKey := range []string{jumpNat, jumpPre} { + if rule, exists := r.rules[ruleKey]; exists { + table := tableNat + chain := chainPOSTROUTING + if ruleKey == jumpPre { + table = tableMangle + chain = chainPREROUTING + } + + if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil { + return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err) + } + delete(r.rules, ruleKey) } } - return nil } @@ -424,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.NatFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { - return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) + if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil { + return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err) } delete(r.rules, ruleKey) } - rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) - if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { - return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) + markValue := nbnet.PreroutingFwmarkMasquerade + if pair.Inverse { + markValue = nbnet.PreroutingFwmarkMasqueradeReturn + } + + rule := []string{"-i", r.wgIface.Name()} + if pair.Inverse { + rule = []string{"!", "-i", r.wgIface.Name()} + } + + rule = append(rule, + "-m", "conntrack", + "--ctstate", "NEW", + "-s", pair.Source.String(), + "-d", pair.Destination.String(), + "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), + ) + + if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { + return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) } r.rules[ruleKey] = rule - return nil } @@ -444,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.NatFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { - return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) + if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil { + return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) } - delete(r.rules, ruleKey) } else { - log.Debugf("nat rule %s not found", ruleKey) + log.Debugf("marking rule %s not found", ruleKey) } return nil @@ -482,16 +555,6 @@ func (r *router) updateState() { } } -func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { - intdir := "-i" - lointdir := "-o" - if inverse { - intdir = "-o" - lointdir = "-i" - } - return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump} -} - func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { var rule []string diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 2d821a9db..861bf8601 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -3,17 +3,18 @@ package iptables import ( + "fmt" "net/netip" "os/exec" "testing" "github.com/coreos/go-iptables/iptables" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + nbnet "github.com/netbirdio/netbird/util/net" ) func isIptablesSupported() bool { @@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, manager.init(nil)) defer func() { - _ = manager.Reset() + assert.NoError(t, manager.Reset(), "shouldn't return error") }() - require.Len(t, manager.rules, 2, "should have created rules map") + // Now 5 rules: + // 1. established rule in forward chain + // 2. jump rule to NAT chain + // 3. jump rule to PRE chain + // 4. static outbound masquerade rule + // 5. static return masquerade rule + require.Len(t, manager.rules, 5, "should have created rules map") - exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) - require.True(t, exists, "postrouting rule should exist") + require.True(t, exists, "postrouting jump rule should exist") + + exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING) + require.True(t, exists, "prerouting jump rule should exist") pair := firewall.RouterPair{ ID: "abc", @@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { Destination: netip.MustParsePrefix("100.100.100.0/24"), Masquerade: true, } - forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} - err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) - require.NoError(t, err, "inserting rule should not return error") - - nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) - - err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) - require.NoError(t, err, "inserting rule should not return error") + err = manager.AddNatRule(pair) + require.NoError(t, err, "adding NAT rule should not return error") err = manager.Reset() require.NoError(t, err, "shouldn't return error") } func TestIptablesManager_AddNatRule(t *testing.T) { - if !isIptablesSupported() { t.SkipNow() } @@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) { require.NoError(t, manager.init(nil)) defer func() { - err := manager.Reset() - if err != nil { - log.Errorf("failed to reset iptables manager: %s", err) - } + assert.NoError(t, manager.Reset(), "shouldn't return error") }() err = manager.AddNatRule(testCase.InputPair) - require.NoError(t, err, "forwarding pair should be inserted") + require.NoError(t, err, "marking rule should be inserted") natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - - exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) - if testCase.InputPair.Masquerade { - require.True(t, exists, "nat rule should be created") - foundNatRule, foundNat := manager.rules[natRuleKey] - require.True(t, foundNat, "nat rule should exist in the map") - require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match") - } else { - require.False(t, exists, "nat rule should not be created") - _, foundNat := manager.rules[natRuleKey] - require.False(t, foundNat, "nat rule should not exist in the map") + markingRule := []string{ + "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", testCase.InputPair.Source.String(), + "-d", testCase.InputPair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade), } - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) + exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) if testCase.InputPair.Masquerade { - require.True(t, exists, "income nat rule should be created") - foundNatRule, foundNat := manager.rules[inNatRuleKey] - require.True(t, foundNat, "income nat rule should exist in the map") - require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") + require.True(t, exists, "marking rule should be created") + foundRule, found := manager.rules[natRuleKey] + require.True(t, found, "marking rule should exist in the map") + require.Equal(t, markingRule, foundRule, "stored marking rule should match") } else { - require.False(t, exists, "nat rule should not be created") - _, foundNat := manager.rules[inNatRuleKey] - require.False(t, foundNat, "income nat rule should not exist in the map") + require.False(t, exists, "marking rule should not be created") + _, found := manager.rules[natRuleKey] + require.False(t, found, "marking rule should not exist in the map") + } + + // Check inverse rule + inversePair := firewall.GetInversePair(testCase.InputPair) + inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair) + inverseMarkingRule := []string{ + "!", "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", inversePair.Source.String(), + "-d", inversePair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn), + } + + exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) + if testCase.InputPair.Masquerade { + require.True(t, exists, "inverse marking rule should be created") + foundRule, found := manager.rules[inverseRuleKey] + require.True(t, found, "inverse marking rule should exist in the map") + require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match") + } else { + require.False(t, exists, "inverse marking rule should not be created") + _, found := manager.rules[inverseRuleKey] + require.False(t, found, "inverse marking rule should not exist in the map") } }) } } func TestIptablesManager_RemoveNatRule(t *testing.T) { - if !isIptablesSupported() { t.SkipNow() } @@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { - _ = manager.Reset() + assert.NoError(t, manager.Reset(), "shouldn't return error") }() - require.NoError(t, err, "shouldn't return error") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - - err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) - require.NoError(t, err, "inserting rule should not return error") - - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) - - err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) - require.NoError(t, err, "inserting rule should not return error") - - err = manager.Reset() - require.NoError(t, err, "shouldn't return error") + err = manager.AddNatRule(testCase.InputPair) + require.NoError(t, err, "should add NAT rule without error") err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) - require.False(t, exists, "nat rule should not exist") + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + markingRule := []string{ + "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", testCase.InputPair.Source.String(), + "-d", testCase.InputPair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade), + } + + exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) + require.False(t, exists, "marking rule should not exist") _, found := manager.rules[natRuleKey] - require.False(t, found, "nat rule should exist in the manager map") + require.False(t, found, "marking rule should not exist in the manager map") - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) - require.False(t, exists, "income nat rule should not exist") + // Check inverse rule removal + inversePair := firewall.GetInversePair(testCase.InputPair) + inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair) + inverseMarkingRule := []string{ + "!", "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", inversePair.Source.String(), + "-d", inversePair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn), + } - _, found = manager.rules[inNatRuleKey] - require.False(t, found, "income nat rule should exist in the manager map") + exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) + require.False(t, exists, "inverse marking rule should not exist") + + _, found = manager.rules[inverseRuleKey] + require.False(t, found, "inverse marking rule should not exist in the map") }) } } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 2a40cd9f6..9391b47ec 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -17,6 +17,7 @@ import ( const ( ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormat = "netbird-fwd-%s-%t" + PreroutingFormat = "netbird-prerouting-%s-%t" NatFormat = "netbird-nat-%s-%t" ) diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index ca7b2e59f..abe890fb9 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { }, &expr.Immediate{ Register: 1, - Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), }, &expr.Meta{ Key: expr.MetaKeyMARK, @@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, - Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), }, &expr.Verdict{ Kind: expr.VerdictJump, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0e7ea71b7..34bc9a9bc 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -21,6 +21,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -124,7 +125,6 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) prio := *nftables.ChainPriorityNATSource - 1 - r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, Table: r.workTable, @@ -133,6 +133,21 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) + // Chain is created by acl manager + // TODO: move creation to a common place + r.chains[chainNamePrerouting] = &nftables.Chain{ + Name: chainNamePrerouting, + Table: r.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + } + + // Add the single NAT rule that matches on mark + if err := r.addPostroutingRules(); err != nil { + return fmt.Errorf("add single nat rule: %v", err) + } + if err := r.acceptForwardRules(); err != nil { log.Errorf("failed to add accept rules for the forward chain: %s", err) } @@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { sourceExp := generateCIDRMatcherExpressions(true, pair.Source) destExp := generateCIDRMatcherExpressions(false, pair.Destination) - dir := expr.MetaKeyIIFNAME - notDir := expr.MetaKeyOIFNAME + op := expr.CmpOpEq if pair.Inverse { - dir = expr.MetaKeyOIFNAME - notDir = expr.MetaKeyIIFNAME + op = expr.CmpOpNeq } - lo := ifname("lo") - intf := ifname(r.wgIface.Name()) - exprs := []expr.Any{ - &expr.Meta{ - Key: dir, + // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. + // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. + &expr.Ct{ + Key: expr.CtKeySTATE, Register: 1, }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, - - // We need to exclude the loopback interface as this changes the ebpf proxy port - &expr.Meta{ - Key: notDir, - Register: 1, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW), + Xor: binaryutil.NativeEndian.PutUint32(0), }, &expr.Cmp{ Op: expr.CmpOpNeq, Register: 1, - Data: lo, + Data: []byte{0, 0, 0, 0}, + }, + + // interface matching + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: op, + Register: 1, + Data: ifname(r.wgIface.Name()), }, } exprs = append(exprs, sourceExp...) exprs = append(exprs, destExp...) + + var markValue uint32 = nbnet.PreroutingFwmarkMasquerade + if pair.Inverse { + markValue = nbnet.PreroutingFwmarkMasqueradeReturn + } + exprs = append(exprs, - &expr.Counter{}, &expr.Masq{}, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(markValue), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + SourceRegister: true, + Register: 1, + }, ) - ruleKey := firewall.GenKey(firewall.NatFormat, pair) + ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if _, exists := r.rules[ruleKey]; exists { if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove routing rule: %w", err) + return fmt.Errorf("remove prerouting rule: %w", err) } } r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, - Chain: r.chains[chainNameRoutingNat], + Chain: r.chains[chainNamePrerouting], Exprs: exprs, UserData: []byte(ruleKey), }) + + return nil +} + +// addPostroutingRules adds the masquerade rules +func (r *router) addPostroutingRules() error { + // First masquerade rule for traffic coming in from WireGuard interface + exprs := []expr.Any{ + // Match on the first fwmark + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade), + }, + + // We need to exclude the loopback interface as this changes the ebpf proxy port + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, + &expr.Counter{}, + &expr.Masq{}, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs, + }) + + // Second masquerade rule for traffic going out through WireGuard interface + exprs2 := []expr.Any{ + // Match on the second fwmark + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn), + }, + + // Match WireGuard interface + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Counter{}, + &expr.Masq{}, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs2, + }) + return nil } @@ -723,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error return nberrors.FormatErrorOrNil(merr) } -// RemoveNatRule removes a nftables rule pair from nat chains +// RemoveNatRule removes the prerouting mark rule func (r *router) RemoveNatRule(pair firewall.RouterPair) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove nat rule: %w", err) + return fmt.Errorf("remove prerouting rule: %w", err) } if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse nat rule: %w", err) + return fmt.Errorf("remove inverse prerouting rule: %w", err) } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -749,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return nil } -// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map func (r *router) removeNatRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.NatFormat, pair) + ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { err := r.conn.DelRule(rule) if err != nil { - return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) + return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) } else { - log.Debugf("nftables: nat rule %s not found", ruleKey) + log.Debugf("nftables: prerouting rule %s not found", ruleKey) } return nil diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 19ed48991..afc4d5c39 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -10,6 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,100 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) { t.Skip("nftables not supported on this OS") } - table, err := createWorkTable() - require.NoError(t, err, "Failed to create work table") - - defer deleteWorkTable() - for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(table, ifaceMock) - require.NoError(t, err, "failed to create router") - require.NoError(t, manager.init(table)) + // need fw manager to init both acl mgr and router for all chains to be present + manager, err := Create(ifaceMock) + t.Cleanup(func() { + require.NoError(t, manager.Reset(nil)) + }) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) nftablesTestingClient := &nftables.Conn{} - defer func(manager *router) { - require.NoError(t, manager.Reset(), "failed to reset rules") - }(manager) - - require.NoError(t, err, "shouldn't return error") - - err = manager.AddNatRule(testCase.InputPair) + rtr := manager.router + err = rtr.AddNatRule(testCase.InputPair) require.NoError(t, err, "pair should be inserted") - defer func(manager *router, pair firewall.RouterPair) { - require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") - }(manager, testCase.InputPair) + t.Cleanup(func() { + require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule") + }) if testCase.InputPair.Masquerade { - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - testingExpression = append(testingExpression, - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + // Build expected expressions for connection tracking + conntrackExprs := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + } + + // Build interface matching expression + ifaceExprs := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: ifname(ifaceMock.Name()), }, - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: ifname("lo"), - }, - ) - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") - found = 1 - } - } } - require.Equal(t, 1, found, "should find at least 1 rule to test") - } - if testCase.InputPair.Masquerade { + // Build CIDR matching expressions sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - testingExpression = append(testingExpression, - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(ifaceMock.Name()), - }, - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: ifname("lo"), - }, - ) - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + // Combine all expressions in the correct order + // nolint:gocritic + testingExpression := append(conntrackExprs, ifaceExprs...) + testingExpression = append(testingExpression, sourceExp...) + testingExpression = append(testingExpression, destExp...) + + natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match") - found = 1 + for _, chain := range rtr.chains { + if chain.Name == chainNamePrerouting { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + // Compare expressions up to the mark setting expressions + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match") + found = 1 + } } } } - require.Equal(t, 1, found, "should find at least 1 rule to test") + require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain") } - }) } } @@ -135,68 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { t.Skip("nftables not supported on this OS") } - table, err := createWorkTable() - require.NoError(t, err, "Failed to create work table") - - defer deleteWorkTable() - for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(table, ifaceMock) - require.NoError(t, err, "failed to create router") - require.NoError(t, manager.init(table)) - - nftablesTestingClient := &nftables.Conn{} - - defer func(manager *router) { - require.NoError(t, manager.Reset(), "failed to reset rules") - }(manager) - - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - - insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRoutingNat], - Exprs: natExp, - UserData: []byte(natRuleKey), + manager, err := Create(ifaceMock) + t.Cleanup(func() { + require.NoError(t, manager.Reset(nil)) }) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) + rtr := manager.router - natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + // First add the NAT rule using the router's method + err = rtr.AddNatRule(testCase.InputPair) + require.NoError(t, err, "should add NAT rule") - insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRoutingNat], - Exprs: natExp, - UserData: []byte(inNatRuleKey), - }) - - err = nftablesTestingClient.Flush() - require.NoError(t, err, "shouldn't return error") - - err = manager.Reset() - require.NoError(t, err, "shouldn't return error") - - err = manager.RemoveNatRule(testCase.InputPair) - require.NoError(t, err, "shouldn't return error") - - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 { - require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") - } + // Verify the rule was added + natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) + found := false + rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) + require.NoError(t, err, "should list rules") + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + found = true + break } } + require.True(t, found, "NAT rule should exist before removal") + + // Now remove the rule + err = rtr.RemoveNatRule(testCase.InputPair) + require.NoError(t, err, "shouldn't return error when removing rule") + + // Verify the rule was removed + found = false + rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) + require.NoError(t, err, "should list rules after removal") + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + found = true + break + } + } + require.False(t, found, "NAT rule should not exist after removal") + + // Verify the static postrouting rules still exist + rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat]) + require.NoError(t, err, "should list postrouting rules") + foundCounter := false + for _, rule := range rules { + for _, e := range rule.Exprs { + if _, ok := e.(*expr.Counter); ok { + foundCounter = true + break + } + } + if foundCounter { + break + } + } + require.True(t, foundCounter, "static postrouting rule should remain") }) } } diff --git a/client/server/server.go b/client/server/server.go index 4d921851f..106bdf32b 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -626,6 +626,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes s.mutex.Lock() defer s.mutex.Unlock() + s.oauthAuthFlow = oauthAuthFlow{} + if s.actCancel == nil { return nil, fmt.Errorf("service is not up") } diff --git a/go.mod b/go.mod index 0ff8b39fe..15851e842 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd + github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 8fb664579..7e10ea160 100644 --- a/go.sum +++ b/go.sum @@ -524,8 +524,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/server/account.go b/management/server/account.go index 1810c6b41..aa7609388 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1249,7 +1249,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", account.Id) + log.Errorf("failed getting account %s expiring peers", accountID) return account.GetNextInactivePeerExpiration() } diff --git a/management/server/account_test.go b/management/server/account_test.go index 1cd4ae449..31405e3af 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -29,14 +29,18 @@ import ( ) type MocIntegratedValidator struct { + ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) } func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { return nil } -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { - return update, nil +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { + if a.ValidatePeerFunc != nil { + return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) + } + return update, false, nil } func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 6c9a3e44e..03be9d039 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -11,7 +11,7 @@ import ( // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error - ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) diff --git a/management/server/management_test.go b/management/server/management_test.go index d53c177d6..5361da53f 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -453,8 +453,8 @@ func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtr return nil } -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { - return update, nil +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { + return update, false, nil } func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { diff --git a/management/server/peer.go b/management/server/peer.go index 7cc2209c5..9c5ab571b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -189,7 +189,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + var requiresPeerUpdates bool + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -265,7 +266,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - if peerLabelUpdated { + if peerLabelUpdated || requiresPeerUpdates { am.updateAccountPeers(ctx, account) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 5127f77fb..78885ea1b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -22,6 +22,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -1398,6 +1399,50 @@ func TestPeerAccountPeersUpdate(t *testing.T) { } }) + t.Run("validator requires update", func(t *testing.T) { + requireUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) { + return update, true, nil + } + + manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc} + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + t.Run("validator requires no update", func(t *testing.T) { + requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) { + return update, false, nil + } + + manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc} + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9cfad5510..b84d928cb 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -70,9 +70,17 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr if err != nil { conns = runtime.NumCPU() } + + if storeEngine == SqliteStoreEngine { + if err == nil { + log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") + } + conns = 1 + } + sql.SetMaxOpenConns(conns) - log.Infof("Set max open db connections to %d", conns) + log.WithContext(ctx).Infof("Set max open db connections to %d", conns) if storeEngine == MysqlStoreEngine { sql.SetConnMaxLifetime(120) diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 10bfbe44d..1ad57d27a 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -63,13 +63,14 @@ func (l *Listener) Shutdown(ctx context.Context) error { } func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { + connRemoteAddr := remoteAddr(r) wsConn, err := websocket.Accept(w, r, nil) if err != nil { - log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err) + log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err) return } - rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) + rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr) if err != nil { err = wsConn.Close(websocket.StatusInternalError, "internal error") if err != nil { @@ -90,3 +91,10 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { conn := NewConn(wsConn, lAddr, rAddr) l.acceptFn(conn) } + +func remoteAddr(r *http.Request) string { + if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" { + return r.RemoteAddr + } + return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port")) +} diff --git a/util/net/net.go b/util/net/net.go index 035d7552b..5448eb85a 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,8 +11,11 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 - PreroutingFwmark = 0x1BD01 + NetbirdFwmark = 0x1BD00 + + PreroutingFwmarkRedirected = 0x1BD01 + PreroutingFwmarkMasquerade = 0x1BD11 + PreroutingFwmarkMasqueradeReturn = 0x1BD12 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" )