Compare commits

...

43 Commits

Author SHA1 Message Date
Viktor Liu
3efa7a282a Set log level debug 2024-11-27 13:46:37 +01:00
Viktor Liu
40551099b3 Add debug 2024-11-27 13:41:32 +01:00
Pascal Fischer
9db1932664 [management] Fix getSetupKey call (#2927) 2024-11-22 10:15:51 +01:00
Viktor Liu
1bbabf70b0 [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict

* Fix chain name
2024-11-21 16:53:37 +01:00
Pascal Fischer
aa575d6f44 [management] Add activity events to group propagation flow (#2916) 2024-11-21 15:10:34 +01:00
Pascal Fischer
f66bbcc54c [management] Add metric for peer meta update (#2913) 2024-11-19 18:13:26 +01:00
Pascal Fischer
5dd6a08ea6 link peer meta update back to account object (#2911) 2024-11-19 17:25:49 +01:00
Krzysztof Nazarewski (kdn)
eb5d0569ae [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)
* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
2024-11-19 14:14:58 +01:00
Pascal Fischer
52ea2e84e9 [management] Add transaction metrics and exclude getAccount time from peers update (#2904) 2024-11-19 00:04:50 +01:00
Maycon Santos
78fab877c0 [misc] Update signing pipeline version (#2900) 2024-11-18 15:31:53 +01:00
Maycon Santos
65a94f695f use google domain for tests (#2902) 2024-11-18 12:55:02 +01:00
Kursat Aktas
ec543f89fb Introducing NetBird Guru on Gurubase.io (#2778) 2024-11-16 15:45:31 +01:00
Viktor Liu
a7d5c52203 Fix error state race on mgmt connection error (#2892) 2024-11-15 22:59:49 +01:00
Viktor Liu
582bb58714 Move state updates outside the refcounter (#2897) 2024-11-15 22:55:33 +01:00
Viktor Liu
121dfda915 [client] Fix state manager race conditions (#2890) 2024-11-15 20:05:26 +01:00
İsmail
a1c5287b7c Fix the Inactivity Expiration problem. (#2865) 2024-11-15 18:21:27 +01:00
Bethuel Mmbaga
12f442439a [management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-15 20:09:32 +03:00
Pascal Fischer
d9b691b8a5 [management] Limit the setup-key update operation (#2841) 2024-11-15 17:00:06 +01:00
Pascal Fischer
4aee3c9e33 [client/management] add peer lock to peer meta update and fix isEqual func (#2840) 2024-11-15 16:59:03 +01:00
Pascal Fischer
44e799c687 [management] Fix limited peer view groups (#2894) 2024-11-15 11:16:16 +01:00
Viktor Liu
be78efbd42 [client] Handle panic on nil wg interface (#2891) 2024-11-14 20:15:16 +01:00
Maycon Santos
6886691213 Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
2024-11-13 15:21:33 +01:00
Zoltan Papp
b48afd92fd [relay-server] Always close ws conn when work thread exit (#2879)
Close ws conn when work thread exit
2024-11-13 15:02:51 +01:00
Viktor Liu
39329e12a1 [client] Improve state write timeout and abort work early on timeout (#2882)
* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
2024-11-13 13:46:00 +01:00
Pascal Fischer
20a5afc359 [management] Add more logs to the peer update processes (#2881) 2024-11-12 14:19:22 +01:00
Bethuel Mmbaga
6cb697eed6 [management] Refactor setup key to use store methods (#2861)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context from DB queries

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 19:46:10 +03:00
Viktor Liu
e0bed2b0fb [client] Fix race conditions (#2869)
* Fix concurrent map access in status

* Fix race when retrieving ctx state error

* Fix race when accessing service controller server instance
2024-11-11 14:55:10 +01:00
Zoltan Papp
30f025e7dd [client] fix/proxy close (#2873)
When the remote peer switches the Relay instance then must to close the proxy connection to the old instance.

It can cause issues when the remote peer switch connects to the Relay instance multiple times and then reconnects to an instance it had previously connected to.
2024-11-11 14:18:38 +01:00
Zoltan Papp
b4d7605147 [client] Remove loop after route calculation (#2856)
- ICE do not trigger disconnect callbacks if the stated did not change
- Fix route calculation callback loop
- Move route state updates into protected scope by mutex
- Do not calculate routes in case of peer.Open() and peer.Close()
2024-11-11 10:53:57 +01:00
Viktor Liu
08b6e9d647 [management] Fix api error message typo peers_group (#2862) 2024-11-08 23:28:02 +01:00
Pascal Fischer
67ce14eaea [management] Add peer lock to grpc server (#2859)
* add peer lock to grpc server

* remove sleep and put db update first

* don't export lock method
2024-11-08 18:47:22 +01:00
Pascal Fischer
669904cd06 [management] Remove context from database calls (#2863) 2024-11-08 15:49:00 +01:00
Zoltan Papp
4be826450b [client] Use offload in WireGuard bind receiver (#2815)
Improve the performance on Linux and Android in case of P2P connections
2024-11-07 17:28:38 +01:00
Maycon Santos
738387f2de Add benchmark tests to get account with claims (#2761)
* Add benchmark tests to get account with claims

* add users to account objects

* remove hardcoded env
2024-11-07 17:23:35 +01:00
Pascal Fischer
baf0678ceb [management] Fix potential panic on inactivity expiration log message (#2854) 2024-11-07 16:33:57 +01:00
Pascal Fischer
7fef8f6758 [management] Enforce max conn of 1 for sqlite setups (#2855) 2024-11-07 16:32:35 +01:00
Viktor Liu
6829a64a2d [client] Exclude split default route ip addresses from anonymization (#2853) 2024-11-07 16:29:32 +01:00
Zoltan Papp
cbf500024f [relay-server] Use X-Real-IP in case of reverse proxy (#2848)
* Use X-Real-IP in case of reverse proxy

* Use sprintf
2024-11-07 16:14:53 +01:00
Viktor Liu
509e184e10 [client] Use the prerouting chain to mark for masquerading to support older systems (#2808) 2024-11-07 12:37:04 +01:00
Pascal Fischer
3e88b7c56e [management] Fix network map update on peer validation (#2849) 2024-11-07 09:50:13 +01:00
Maycon Santos
b952d8693d Fix cached device flow oauth (#2833)
This change removes the cached device flow oauth info when a down command is called

Removing the need for the agent to be restarted
2024-11-05 14:51:17 +01:00
Maycon Santos
5b46cc8e9c Avoid failing all other matrix tests if one fails (#2839) 2024-11-05 13:28:42 +01:00
Pascal Fischer
a9d06b883f add all group to add peer affected peers network map check (#2830) 2024-11-01 22:09:08 +01:00
88 changed files with 2957 additions and 1418 deletions

View File

@@ -13,6 +13,7 @@ concurrency:
jobs:
test:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.16"
SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -19,6 +19,10 @@
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>

View File

@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
}
if slices.Contains(wellKnown, addr.String()) {

View File

@@ -2,6 +2,7 @@ package cmd
import (
"context"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
@@ -13,10 +14,11 @@ import (
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel()

View File

@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}

View File

@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
// persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}

View File

@@ -18,22 +18,24 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
ipv4Nat = "netbird-rt-nat"
nbnet "github.com/netbirdio/netbird/util/net"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
)
@@ -323,24 +325,25 @@ func (r *router) Reset() error {
}
func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules()
if err != nil {
return err
if err := r.cleanJumpRules(); err != nil {
return fmt.Errorf("clean jump rules: %w", err)
}
log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
}
@@ -349,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %w", chain, err)
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
@@ -359,6 +369,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
@@ -366,6 +380,32 @@ func (r *router) createContainers() error {
return nil
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
@@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
}
func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
@@ -398,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
}
func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
}
r.rules[ipv4Nat] = rule
r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil
}
func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat]
if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
}
}
return nil
}
@@ -424,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
@@ -444,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("nat rule %s not found", ruleKey)
log.Debugf("marking rule %s not found", ruleKey)
}
return nil
@@ -482,16 +555,6 @@ func (r *router) updateState() {
}
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string

View File

@@ -3,17 +3,18 @@
package iptables
import (
"fmt"
"net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
)
func isIptablesSupported() bool {
@@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, manager.init(nil))
defer func() {
_ = manager.Reset()
assert.NoError(t, manager.Reset(), "shouldn't return error")
}()
require.Len(t, manager.rules, 2, "should have created rules map")
// Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{
ID: "abc",
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.AddNatRule(pair)
require.NoError(t, err, "adding NAT rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
require.NoError(t, manager.init(nil))
defer func() {
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
assert.NoError(t, manager.Reset(), "shouldn't return error")
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
require.True(t, found, "marking rule should exist in the map")
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
require.False(t, exists, "marking rule should not be created")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
_ = manager.Reset()
assert.NoError(t, manager.Reset(), "shouldn't return error")
}()
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
require.False(t, found, "marking rule should not exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
})
}
}

View File

@@ -17,6 +17,7 @@ import (
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
)

View File

@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictJump,

View File

@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
// persist early
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
}

View File

@@ -21,6 +21,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -124,7 +125,6 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
@@ -133,6 +133,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
@@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME
notDir := expr.MetaKeyOIFNAME
op := expr.CmpOpEq
if pair.Inverse {
dir = expr.MetaKeyOIFNAME
notDir = expr.MetaKeyIIFNAME
op = expr.CmpOpNeq
}
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
&expr.Meta{
Key: dir,
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: lo,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Counter{}, &expr.Masq{},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Chain: r.chains[chainNamePrerouting],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
@@ -723,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes a nftables rule pair from nat chains
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -749,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: nat rule %s not found", ruleKey)
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
}
return nil

View File

@@ -10,6 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -32,100 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
}(manager, testCase.InputPair)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
})
if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
// Build expected expressions for connection tracking
conntrackExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
if testCase.InputPair.Masquerade {
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
}
})
}
}
@@ -135,68 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
rtr := manager.router
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.True(t, found, "NAT rule should exist before removal")
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
})
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
@@ -24,8 +25,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -154,7 +155,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
@@ -166,16 +167,30 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
msgs := getMessages(msgsPool)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
@@ -191,11 +206,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
@@ -273,3 +289,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -2,6 +2,7 @@ package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
@@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
p.Bind.RemoveEndpoint(p.wgAddr)
return p.remoteConn.Close()
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {

View File

@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
if err := e.remoteConn.Close(); err != nil {
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil

View File

@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}

View File

@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err
}
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config)
return util.WriteJson(context.Background(), path, config)
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
}
if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil {
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}

View File

@@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx)
defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err)
_, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.CleanLocalPeerState()
cancel()
}()
@@ -207,7 +208,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(nil)
defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err)
_, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
}()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal

View File

@@ -7,7 +7,6 @@ import (
"runtime"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
go func() {
// persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
}()
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()

View File

@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Domains: []string{"google.com"},
Primary: false,
},
},
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
_, err = resolver.LookupHost(context.Background(), "google.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}

View File

@@ -11,6 +11,7 @@ import (
"reflect"
"runtime"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
@@ -38,7 +39,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
@@ -171,7 +171,7 @@ type Engine struct {
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
srWatcher *guard.SRWatcher
}
// Peer is an instance of the Connection Peer
@@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(ctx); err != nil {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
@@ -641,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -1481,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})

View File

@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
}
}
func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {

View File

@@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
@@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
@@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
}
}
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return s.routes
return maps.Clone(s.routes)
}
// LocalPeerState contains the latest state of the local peer
@@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
d.notifyPeerListChanged()
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.AddRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.DeleteRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
@@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
@@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()
ch, found := d.changeNotify[peer]
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
if found {
return ch
}
ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}
@@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}
close(ch)
delete(d.changeNotify, peerID)
}
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}

View File

@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
peerState.IP = ip
err := status.UpdatePeerState(peerState)
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error")
select {

View File

@@ -57,6 +57,9 @@ type WorkerICE struct {
localUfrag string
localPwd string
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
@@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
return
}
err := w.agent.Close()
if err != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
@@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected)
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
@@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)

View File

@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
}
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
}
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
switch {
case chosen == "":
var peers []string
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
if found {
continue
}
closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
@@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
@@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
@@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
c.addStateRoute()
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update

View File

@@ -1,6 +1,7 @@
package routemanager
import (
"fmt"
"net/netip"
"testing"
"time"
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
},
}
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{

View File

@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key]
return ref, ok
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
ref, err := rm.Increment(key, in)
ref, err := rm.increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.decrement(key)
}
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key]
if !ok {
logCallerF("No reference found for key %v", key)
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
var merr *multierror.Error
for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil {
if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
var merr *multierror.Error
for key := range rm.refCountMap {
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
clear(rm.refCountMap)
clear(rm.idMap)
@@ -217,6 +217,9 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.mu.Lock()
defer rm.mu.Unlock()
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`

View File

@@ -2,31 +2,28 @@ package systemops
import (
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
type ShutdownState struct {
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
type ShutdownState ExclusionCounter
func (s *ShutdownState) Name() string {
return "route_state"
}
func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter)
sysops.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush()
}
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

View File

@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore
}
r.updateState(stateManager)
return nexthop, err
},
func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop)
},
r.removeFromRouteTable,
)
r.refCounter = refCounter
return r.setupHooks(initAddresses)
return r.setupHooks(initAddresses, stateManager)
}
// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager)
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("adding route reference: %v", err)
}
r.updateState(stateManager)
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
}
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}
return shutdownState
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
}
// setIsLegacy sets the legacy routing setup

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
)
// State interface defines the methods that all state types must implement
@@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
m.cancel()
if m.cancel == nil {
return nil
}
m.cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
}
return nil
@@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
done := make(chan error, 1)
start := time.Now()
go func() {
data, err := json.MarshalIndent(m.states, "", " ")
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
}()
select {
@@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
}
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty)
@@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
return nberrors.FormatErrorOrNil(merr)
}
func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error
func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()
return bs, err
}

View File

@@ -4,32 +4,20 @@ import (
"os"
"path/filepath"
"runtime"
log "github.com/sirupsen/logrus"
)
// GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
// It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string {
var path string
switch runtime.GOOS {
case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
path = "/var/lib/netbird/state.json"
return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
return "/var/db/netbird/state.json"
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return ""
return path
}

View File

@@ -1,13 +1,41 @@
package main
import (
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/cmd"
)
func main() {
// only if env is not set
if os.Getenv("NB_LOG_LEVEL") == "" {
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
log.Errorf("Failed setting log-level: %v", err)
}
}
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
log.Errorf("Failed setting log-size: %v", err)
}
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
log.Errorf("Failed setting panic log path: %v", err)
}
go startPprofServer()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}
func startPprofServer() {
pprofAddr := "localhost:6969"
log.Infof("Starting pprof debugging server on %s", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Infof("pprof server failed: %v", err)
}
}

View File

@@ -120,6 +120,35 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
log.Infof("Error getting status: %v", err)
} else if statusResp.FullStatus != nil {
log.Infof("Status --------")
for _, peer := range statusResp.FullStatus.Peers {
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
peer.Fqdn,
peer.IP,
peer.PubKey,
peer.ConnStatus,
peer.Relayed,
peer.RelayAddress,
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
)
}
}
}
}
}()
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
// on failure we return error to retry
config, err := internal.UpdateConfig(s.latestConfigInput)
@@ -626,6 +655,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
s.mutex.Lock()
defer s.mutex.Unlock()
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil {
return nil, fmt.Errorf("service is not up")
}

4
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

8
go.sum
View File

@@ -521,14 +521,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=

View File

@@ -110,7 +110,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@@ -966,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
}
// UserGroupsAddToPeers adds groups to all peers of user
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
userPeers := make(map[string]struct{})
for pid, peer := range a.Peers {
if peer.UserID == userID {
@@ -980,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
continue
}
oldPeers := group.Peers
groupPeers := make(map[string]struct{})
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
@@ -993,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
groupUpdates[gid] = difference(group.Peers, oldPeers)
}
return groupUpdates
}
// UserGroupsRemoveFromPeers removes groups from all peers of user
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
for _, gid := range groups {
group, ok := a.Groups[gid]
if !ok || group.Name == "All" {
continue
}
oldPeers := group.Peers
update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
peer, ok := a.Peers[pid]
@@ -1014,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
}
}
group.Peers = update
groupUpdates[gid] = difference(oldPeers, group.Peers)
}
return groupUpdates
}
// BuildManager creates a new DefaultAccountManager with a provided Store
@@ -1176,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
@@ -1186,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
if newSettings.GroupsPropagationEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
// Todo: retroactively add user groups to all peers
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
return nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
}
return nil
@@ -1249,7 +1287,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id)
log.Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextInactivePeerExpiration()
}
@@ -1435,7 +1473,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -2029,7 +2067,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err)
}
groups, err := transaction.GetAccountGroups(ctx, accountID)
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -2059,7 +2097,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID)
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -2083,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@@ -2101,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@@ -2114,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@@ -2127,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
return err
}
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
}
@@ -2290,12 +2333,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, status.NewGetAccountError(err)
}
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
@@ -2314,12 +2357,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
return status.NewGetAccountError(err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
return nil
@@ -2335,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
@@ -2398,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
am.updateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@@ -29,14 +29,18 @@ import (
)
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
@@ -978,6 +982,110 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := jwtclaims.AuthorizationClaims{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
}
publicClaims := jwtclaims.AuthorizationClaims{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: PublicCategory,
}
am, err := createManager(b)
if err != nil {
b.Fatal(err)
return
}
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
users := genUsers("priv", 100)
acc, err := am.Store.GetAccount(context.Background(), id)
if err != nil {
b.Fatal(err)
}
acc.Users = users
err = am.Store.SaveAccount(context.Background(), acc)
if err != nil {
b.Fatal(err)
}
userP := genUsers("pub", 100)
pacc, err := am.Store.GetAccount(context.Background(), pid)
if err != nil {
b.Fatal(err)
}
pacc.Users = userP
err = am.Store.SaveAccount(context.Background(), pacc)
if err != nil {
b.Fatal(err)
}
b.Run("public without account ID", func(b *testing.B) {
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private without account ID", func(b *testing.B) {
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
}
func genUsers(p string, n int) map[string]*User {
users := map[string]*User{}
now := time.Now()
for i := 0; i < n; i++ {
users[fmt.Sprintf("%s-%d", p, i)] = &User{
Id: fmt.Sprintf("%s-%d", p, i),
Role: UserRoleAdmin,
LastLogin: now,
CreatedAt: now,
Issued: "api",
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
}
}
return users
}
func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -1305,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}
})
require.NoError(t, err, "failed to save group")
policy := Policy{
Enabled: true,
@@ -1352,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err)
return
}
@@ -2606,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2626,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2665,7 +2775,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")

View File

@@ -148,6 +148,9 @@ const (
AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68
UserGroupPropagationEnabled Activity = 69
UserGroupPropagationDisabled Activity = 70
)
var activityMap = map[Activity]Code{
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
}
// StringCode returns a string code of the activity

View File

@@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
err := util.WriteJson(file, s)
err := util.WriteJson(context.Background(), file, s)
if err != nil {
return err
}

View File

@@ -6,11 +6,12 @@ import (
"fmt"
"slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
return nil
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetAccountGroups(ctx, accountID)
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
newGroup.ID = xid.New().String()
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
}
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
})
if err != nil {
return err
}
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if oldGroup != nil {
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
})
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
})
}
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
for _, peerID := range removedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
})
}
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
}
// DeleteGroups deletes groups from an account.
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
})
if err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
for _, group := range deletedGroups {
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
}
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
@@ -477,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
@@ -486,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
}
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}

View File

@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
func (g *Group) HasPeers() bool {
return len(g.Peers) > 0
}
// IsGroupAll checks if the group is a default "All" group.
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@@ -0,0 +1,90 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"},
expectedReasons: []string{"group: non-existent-group not found"},
},
{
name: "delete multiple groups with mixed results",

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -38,6 +39,7 @@ type GRPCServer struct {
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
}
// NewServer creates a new Management server
@@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
@@ -171,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err)
}
@@ -190,11 +200,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
unlock()
unlock = nil
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for {
select {
// condition when there are some updates
@@ -245,10 +259,18 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -274,6 +296,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return claims.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
start = time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {

View File

@@ -439,17 +439,13 @@ components:
example: 5
required:
- accessible_peers_count
SetupKey:
SetupKeyBase:
type: object
properties:
id:
description: Setup Key ID
type: string
example: 2531583362
key:
description: Setup Key value
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
name:
description: Setup key name identifier
type: string
@@ -518,22 +514,31 @@ components:
- updated_at
- usage_limit
- ephemeral
SetupKeyClear:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as plain text
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
required:
- key
SetupKey:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as secret
type: string
example: A6160****
required:
- key
SetupKeyRequest:
type: object
properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds, 0 will mean the key never expires
type: integer
minimum: 0
example: 86400
revoked:
description: Setup key revocation status
type: boolean
@@ -544,21 +549,9 @@ components:
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required:
- name
- type
- expires_in
- revoked
- auto_groups
- usage_limit
CreateSetupKeyRequest:
type: object
properties:
@@ -1943,7 +1936,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/SetupKey'
$ref: '#/components/schemas/SetupKeyClear'
'400':
"$ref": "#/components/responses/bad_request"
'401':

View File

@@ -1062,7 +1062,94 @@ type SetupKey struct {
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key value
// Key Setup Key as secret
Key string `json:"key"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyBase defines model for SetupKeyBase.
type SetupKeyBase struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyClear defines model for SetupKeyClear.
type SetupKeyClear struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key as plain text
Key string `json:"key"`
// LastUsed Setup key last usage date
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
}
// User defines model for User.

View File

@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range account.Peers {
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupsMap := map[string]*nbgroup.Group{}
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range groups {
groupsMap[group.ID] = group
}
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{})
for _, group := range groups {
_, ok := groupsChecked[group.ID]

View File

@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
}
if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
}
if req.Peer != nil && req.PeerGroups != nil {

View File

@@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
return
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return
@@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)

View File

@@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a)
}
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
if len(groups) == 0 {
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil {
return false, err
}
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
}
if !found {
return false, nil
}
return nil
})
if err != nil {
return false, err
}
return true, nil

View File

@@ -11,7 +11,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)

View File

@@ -453,8 +453,8 @@ func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtr
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {

View File

@@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil {

View File

@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

View File

@@ -110,14 +110,16 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
return fmt.Errorf("failed to find peer by pub key: %w", err)
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
if err != nil {
return err
return fmt.Errorf("failed to update peer status and location: %w", err)
}
log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected)
if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
@@ -131,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
@@ -166,9 +168,11 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
account.UpdatePeer(peer)
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil {
return false, err
return false, fmt.Errorf("failed to save peer status: %w", err)
}
return oldStatus.LoginExpired, nil
@@ -189,7 +193,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
}
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
var requiresPeerUpdates bool
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, err
}
@@ -265,8 +270,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err
}
if peerLabelUpdated {
am.updateAccountPeers(ctx, account)
if peerLabelUpdated || requiresPeerUpdates {
am.updateAccountPeers(ctx, accountID)
}
return peer, nil
@@ -330,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers := isPeerInActiveGroup(account, peerID)
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
if err != nil {
return err
}
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
@@ -343,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -550,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
@@ -586,11 +594,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
return nil, nil, nil, status.NewGetAccountError(err)
}
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
allGroup, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
if err != nil {
return nil, nil, nil, err
}
if newGroupsAffectsPeers {
am.updateAccountPeers(ctx, accountID)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -633,7 +652,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if peer.UserID != "" {
user, err := account.FindUser(peer.UserID)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
}
err = checkIfPeerOwnerIsBlocked(peer, user)
@@ -648,19 +667,22 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
account.Peers[peer.ID] = peer
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
}
var postureChecks []*posture.Checks
@@ -673,12 +695,12 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if isStatusChanged {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}
postureChecks = am.getPeerPostureChecks(account, peer)
@@ -758,7 +780,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -780,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
shouldStorePeer = true
}
@@ -804,7 +827,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
@@ -967,7 +990,13 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
return
}
start := time.Now()
defer func() {
if am.metrics != nil {
@@ -1021,12 +1050,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
}

View File

@@ -22,6 +22,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -876,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
manager.updateAccountPeers(ctx, account.Id)
}
duration := time.Since(start)
@@ -1398,6 +1399,50 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}
})
t.Run("validator requires update", func(t *testing.T) {
requireUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, true, nil
}
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
t.Run("validator requires no update", func(t *testing.T) {
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{

View File

@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups {

View File

@@ -4,8 +4,8 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/fnv"
"slices"
"strconv"
"strings"
"time"
@@ -229,32 +229,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
return nil, err
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(ctx, account)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var setupKey *SetupKey
var plainKey string
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err
}
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
setupKey.AccountID = accountID
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
})
if err != nil {
return nil, status.Errorf(status.Internal, "failed adding account key")
return nil, err
}
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
for _, g := range setupKey.AutoGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
// for the creation return the plain key to the caller
@@ -266,45 +277,61 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
// Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
}
account, err := am.Store.GetAccount(ctx, accountID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var oldKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyToSave.Id {
oldKey = key.Copy()
break
var newKey *SetupKey
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
return err
}
}
if oldKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
return nil, err
}
oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
if err != nil {
return err
}
// only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy()
newKey.Name = keyToSave.Name
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
if oldKey.Revoked && !keyToSave.Revoked {
return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
}
account.SetupKeys[newKey.Key] = newKey
// only auto groups, revoked status (from false to true) can be updated
newKey = oldKey.Copy()
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
if err = am.Store.SaveAccount(ctx, account); err != nil {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
})
if err != nil {
return nil, err
}
@@ -312,30 +339,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta())
}
defer func() {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
}
}()
for _, storeEvent := range eventsToStore {
storeEvent()
}
return newKey, nil
}
@@ -347,16 +353,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return setupKeys, nil
return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -366,11 +371,15 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil {
return nil, err
}
@@ -387,21 +396,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
return err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var deletedSetupKey *SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil {
return err
}
return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
})
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err)
return err
}
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
@@ -409,15 +426,62 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return nil
}
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
for _, group := range autoGroups {
g, ok := account.Groups[group]
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
if err != nil {
return err
}
for _, groupID := range autoGroupIDs {
group, ok := groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
return status.Errorf(status.NotFound, "group not found: %s", groupID)
}
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
}
}
return nil
}
// prepareSetupKeyEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil
}
for _, g := range removedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
})
}
for _, g := range addedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
})
}
return eventsToStore
}

View File

@@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
}
autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key"
revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
@@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err)
}
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated
@@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"])
assert.Equal(t, keyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID)
@@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
@@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
type testCase struct {
name string
keyId string
expectedFailure bool
}
testCase1 := testCase{
name: "Should get existing Setup Key",
keyId: plainKey.Id,
expectedFailure: false,
}
testCase2 := testCase{
name: "Should fail to get non-existent Setup Key",
keyId: "some key",
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
if tCase.expectedFailure {
if err == nil {
t.Fatal("expected to fail")
}
return
}
assert.NotEqual(t, plainKey.Key, key.Key)
})
}
}
@@ -449,3 +465,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}
})
}
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
assert.NoError(t, err)
// revoke the key
updateKey := key.Copy()
updateKey.Revoked = true
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
assert.NoError(t, err)
// re-activate revoked key
updateKey.Revoked = false
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
assert.Error(t, err, "should not allow to update revoked key")
}

File diff suppressed because it is too large Load Diff

View File

@@ -14,11 +14,10 @@ import (
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route"
@@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil {
t.Fatal("failed to get group")
return err
@@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
@@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, err)
require.Equal(t, "All", group.Name)
require.True(t, group.IsGroupAll())
}
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
@@ -1274,7 +1273,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID)
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
@@ -1290,6 +1289,278 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID)
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err)
}
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectedCount int
}{
{
name: "retrieve existing groups by existing IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectedCount: 2,
},
{
name: "empty group IDs list",
groupIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing group IDs",
groupIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing group IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
require.NoError(t, err)
require.Len(t, groups, tt.expectedCount)
})
}
}
func TestSqlStore_SaveGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &nbgroup.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
require.NoError(t, err)
require.Equal(t, savedGroup, group)
}
func TestSqlStore_SaveGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups := []*nbgroup.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
require.NoError(t, err)
}
func TestSqlStore_DeleteGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupID string
expectError bool
}{
{
name: "delete existing group",
groupID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "delete non-existing group",
groupID: "non-existing-group-id",
expectError: true,
},
{
name: "delete with empty group ID",
groupID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
} else {
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
require.Error(t, err)
require.Nil(t, group)
}
})
}
}
func TestSqlStore_DeleteGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectError bool
}{
{
name: "delete multiple existing groups",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectError: false,
},
{
name: "delete non-existing groups",
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
expectError: false,
},
{
name: "delete with empty group IDs list",
groupIDs: []string{},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
for _, groupID := range tt.groupIDs {
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
require.Error(t, err)
require.Nil(t, group)
}
}
})
}
}
func TestSqlStore_GetPeerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerID string
expectError bool
}{
{
name: "retrieve existing peer",
peerID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "retrieve non-existing peer",
peerID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty peer ID",
peerID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, peer)
} else {
require.NoError(t, err)
require.NotNil(t, peer)
require.Equal(t, tt.peerID, peer.ID)
}
})
}
}
func TestSqlStore_GetPeersByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerIDs []string
expectedCount int
}{
{
name: "retrieve existing peers by existing IDs",
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
expectedCount: 2,
},
{
name: "empty peer IDs list",
peerIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing peer IDs",
peerIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing peer IDs",
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
})
}
}

View File

@@ -3,7 +3,6 @@ package status
import (
"errors"
"fmt"
"time"
)
const (
@@ -103,22 +102,27 @@ func NewPeerLoginExpiredError() error {
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found: %s", err)
func NewSetupKeyNotFoundError(setupKeyID string) error {
return Errorf(NotFound, "setup key: %s not found", setupKeyID)
}
func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err)
}
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
func NewUserNotPartOfAccountError() error {
return Errorf(PermissionDenied, "user is not part of this account")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
func NewAdminPermissionError() error {
return Errorf(PermissionDenied, "admin role required to perform this action")
}
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
@@ -126,7 +130,12 @@ func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")
}
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
func NewUnauthorizedToViewSetupKeysError() error {
return Errorf(Unauthorized, "only users with admin power can view setup keys")
// NewGetAccountError creates a new Error with Internal type for an issue getting account
func NewGetAccountError(err error) error {
return Errorf(Internal, "error getting account: %s", err)
}
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
}

View File

@@ -62,7 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@@ -70,11 +70,14 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -89,6 +92,8 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@@ -96,7 +101,9 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
@@ -105,7 +112,7 @@ type Store interface {
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string
@@ -124,7 +131,6 @@ type Store interface {
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
}
type StoreEngine string

View File

@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
@@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err
}
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount,
peerMetaUpdateCount: peerMetaUpdateCount,
}, nil
}
@@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}
// CountPeerMetUpdate counts the number of peer meta updates
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
}

View File

@@ -13,6 +13,7 @@ type StoreMetrics struct {
globalLockAcquisitionDurationMs metric.Int64Histogram
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
ctx context.Context
}
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err
}
transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
if err != nil {
return nil, err
}
return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
ctx: ctx,
}, nil
}
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
}
// CountTransactionDuration counts the duration of a store persistence operation
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
}

View File

@@ -96,9 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
return
}
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
}
// CloseChannels closes updates channel for each given peer

View File

@@ -9,14 +9,16 @@ import (
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
@@ -103,6 +105,11 @@ func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser
}
// IsRegularUser checks if the user is a regular user.
func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
@@ -487,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
@@ -798,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
expiredPeers = append(expiredPeers, blockedPeers...)
}
peerGroupsAdded := make(map[string][]string)
peerGroupsRemoved := make(map[string][]string)
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
}
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...)
userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, userUpdateEvents...)
userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
eventsToStore = append(eventsToStore, userGroupsEvents...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil {
@@ -828,7 +840,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
for _, storeEvent := range eventsToStore {
@@ -865,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
})
}
return eventsToStore
}
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
}
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
eventsToStore = append(eventsToStore, removedEvents...)
addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
eventsToStore = append(eventsToStore, addedEvents...)
}
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
var eventsToStore []func()
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
}
}
for groupID, peerIDs := range peerGroupsAdded {
group := account.GetGroup(groupID)
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
})
}
}
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
for groupID, peerIDs := range peerGroupsRemoved {
group := account.GetGroup(groupID)
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
})
}
}
return eventsToStore
}
@@ -1100,6 +1158,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
var peerIDs []string
for _, peer := range peers {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
if peer.Status.LoginExpired {
continue
}
@@ -1107,8 +1168,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
return err
return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err)
}
log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID)
am.StoreEvent(
ctx,
peer.UserID, peer.ID, account.Id,
@@ -1119,7 +1183,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
}
@@ -1227,7 +1291,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
for targetUserID, meta := range deletedUsersMeta {

View File

@@ -3,7 +3,6 @@ package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
@@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
conn, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
return 0, net.ErrClosed
}
if conn.conn != connReference {
return 0, io.EOF
return 0, net.ErrClosed
}
// todo: use buffer pool instead of create new transport msg.
@@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
return net.ErrClosed
}
if container.conn != connReference {

View File

@@ -1,7 +1,6 @@
package client
import (
"io"
"net"
"time"
)
@@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
return 0, net.ErrClosed
}
n = copy(b, msg.Payload)

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
@@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
return net.ErrClosed
}
var wErr *websocket.CloseError
@@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
return err
}
if wErr.Code == websocket.StatusNormalClosure {
return io.EOF
return net.ErrClosed
}
return err
}

View File

@@ -63,13 +63,14 @@ func (l *Listener) Shutdown(ctx context.Context) error {
}
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
connRemoteAddr := remoteAddr(r)
wsConn, err := websocket.Accept(w, r, nil)
if err != nil {
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
return
}
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
@@ -90,3 +91,10 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}
func remoteAddr(r *http.Request) string {
if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" {
return r.RemoteAddr
}
return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}

View File

@@ -2,7 +2,7 @@ package server
import (
"context"
"io"
"errors"
"net"
"sync"
"time"
@@ -16,6 +16,8 @@ import (
const (
bufferSize = 8820
errCloseConn = "failed to close connection to peer: %s"
)
// Peer represents a peer connection
@@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly.
func (p *Peer) Work() {
defer func() {
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err)
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -57,7 +65,7 @@ func (p *Peer) Work() {
for {
n, err := p.conn.Read(buf)
if err != nil {
if err != io.EOF {
if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err)
}
return
@@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully")
if err := p.conn.Close(); err != nil {
log.Errorf("failed to close connection to peer: %s", err)
log.Errorf(errCloseConn, err)
}
default:
p.log.Warnf("received unexpected message type: %s", msgType)
@@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err)
}
}
@@ -132,7 +139,7 @@ func (p *Peer) Close() {
defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
p.log.Errorf(errCloseConn, err)
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
@@ -14,8 +15,21 @@ import (
log "github.com/sirupsen/logrus"
)
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return fmt.Errorf("prepare config file dir: %w", err)
}
if err = EnforcePermission(file); err != nil {
return fmt.Errorf("enforce permission: %w", err)
}
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
@@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
return err
}
return writeJson(file, obj, configDir, configFileName)
return writeJson(ctx, file, obj, configDir, configFileName)
}
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
func WriteJson(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
return writeJson(file, obj, configDir, configFileName)
return writeJson(ctx, file, obj, configDir, configFileName)
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
return fmt.Errorf("write json start: %w", ctx.Err())
}
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
return fmt.Errorf("marshal: %w", err)
}
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
if ctx.Err() != nil {
return fmt.Errorf("write bytes start: %w", ctx.Err())
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
return fmt.Errorf("create temp: %w", err)
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if deadline, ok := ctx.Deadline(); ok {
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
log.Warnf("failed to set deadline: %v", err)
}
}
_, err = tempFile.Write(bs)
if err != nil {
return err
_ = tempFile.Close()
return fmt.Errorf("write: %w", err)
}
if err = tempFile.Close(); err != nil {
return fmt.Errorf("close %s: %w", tempFileName, err)
}
defer func() {
@@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
// Check context again
if ctx.Err() != nil {
return fmt.Errorf("after temp file: %w", ctx.Err())
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
if err = os.Rename(tempFileName, file); err != nil {
return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
}
return nil

View File

@@ -1,6 +1,7 @@
package util
import (
"context"
"crypto/md5"
"encoding/hex"
"io"
@@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
@@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := WriteJson(src, tt.srcContent)
err := WriteJson(context.Background(), src, tt.srcContent)
require.NoError(t, err)
err = CopyFileContents(src, dst)
@@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
_ = os.Remove(cfgFile)
}()
err := WriteJson(cfgFile, tt.config)
err := WriteJson(context.Background(), cfgFile, tt.config)
require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{})

View File

@@ -3,6 +3,9 @@ package grpc
import (
"context"
"crypto/tls"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net"
"os/user"
"runtime"
@@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
log.Fatalf("failed to get current user: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
log.Debug("Using nbnet.NewDialer()")
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, err
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})

View File

@@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dial: %w", err)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks

View File

@@ -11,8 +11,11 @@ import (
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
PreroutingFwmark = 0x1BD01
NetbirdFwmark = 0x1BD00
PreroutingFwmarkRedirected = 0x1BD01
PreroutingFwmarkMasquerade = 0x1BD11
PreroutingFwmarkMasqueradeReturn = 0x1BD12
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)

View File

@@ -4,9 +4,14 @@ package net
import (
"fmt"
"os"
"syscall"
log "github.com/sirupsen/logrus"
)
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
// SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error {
sysconn, err := conn.SyscallConn()
@@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return nil
}
// Check for the new environment variable
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil
}