Compare commits

...

43 Commits

Author SHA1 Message Date
Viktor Liu
3efa7a282a Set log level debug 2024-11-27 13:46:37 +01:00
Viktor Liu
40551099b3 Add debug 2024-11-27 13:41:32 +01:00
Pascal Fischer
9db1932664 [management] Fix getSetupKey call (#2927) 2024-11-22 10:15:51 +01:00
Viktor Liu
1bbabf70b0 [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict

* Fix chain name
2024-11-21 16:53:37 +01:00
Pascal Fischer
aa575d6f44 [management] Add activity events to group propagation flow (#2916) 2024-11-21 15:10:34 +01:00
Pascal Fischer
f66bbcc54c [management] Add metric for peer meta update (#2913) 2024-11-19 18:13:26 +01:00
Pascal Fischer
5dd6a08ea6 link peer meta update back to account object (#2911) 2024-11-19 17:25:49 +01:00
Krzysztof Nazarewski (kdn)
eb5d0569ae [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)
* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
2024-11-19 14:14:58 +01:00
Pascal Fischer
52ea2e84e9 [management] Add transaction metrics and exclude getAccount time from peers update (#2904) 2024-11-19 00:04:50 +01:00
Maycon Santos
78fab877c0 [misc] Update signing pipeline version (#2900) 2024-11-18 15:31:53 +01:00
Maycon Santos
65a94f695f use google domain for tests (#2902) 2024-11-18 12:55:02 +01:00
Kursat Aktas
ec543f89fb Introducing NetBird Guru on Gurubase.io (#2778) 2024-11-16 15:45:31 +01:00
Viktor Liu
a7d5c52203 Fix error state race on mgmt connection error (#2892) 2024-11-15 22:59:49 +01:00
Viktor Liu
582bb58714 Move state updates outside the refcounter (#2897) 2024-11-15 22:55:33 +01:00
Viktor Liu
121dfda915 [client] Fix state manager race conditions (#2890) 2024-11-15 20:05:26 +01:00
İsmail
a1c5287b7c Fix the Inactivity Expiration problem. (#2865) 2024-11-15 18:21:27 +01:00
Bethuel Mmbaga
12f442439a [management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-15 20:09:32 +03:00
Pascal Fischer
d9b691b8a5 [management] Limit the setup-key update operation (#2841) 2024-11-15 17:00:06 +01:00
Pascal Fischer
4aee3c9e33 [client/management] add peer lock to peer meta update and fix isEqual func (#2840) 2024-11-15 16:59:03 +01:00
Pascal Fischer
44e799c687 [management] Fix limited peer view groups (#2894) 2024-11-15 11:16:16 +01:00
Viktor Liu
be78efbd42 [client] Handle panic on nil wg interface (#2891) 2024-11-14 20:15:16 +01:00
Maycon Santos
6886691213 Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
2024-11-13 15:21:33 +01:00
Zoltan Papp
b48afd92fd [relay-server] Always close ws conn when work thread exit (#2879)
Close ws conn when work thread exit
2024-11-13 15:02:51 +01:00
Viktor Liu
39329e12a1 [client] Improve state write timeout and abort work early on timeout (#2882)
* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
2024-11-13 13:46:00 +01:00
Pascal Fischer
20a5afc359 [management] Add more logs to the peer update processes (#2881) 2024-11-12 14:19:22 +01:00
Bethuel Mmbaga
6cb697eed6 [management] Refactor setup key to use store methods (#2861)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context from DB queries

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 19:46:10 +03:00
Viktor Liu
e0bed2b0fb [client] Fix race conditions (#2869)
* Fix concurrent map access in status

* Fix race when retrieving ctx state error

* Fix race when accessing service controller server instance
2024-11-11 14:55:10 +01:00
Zoltan Papp
30f025e7dd [client] fix/proxy close (#2873)
When the remote peer switches the Relay instance then must to close the proxy connection to the old instance.

It can cause issues when the remote peer switch connects to the Relay instance multiple times and then reconnects to an instance it had previously connected to.
2024-11-11 14:18:38 +01:00
Zoltan Papp
b4d7605147 [client] Remove loop after route calculation (#2856)
- ICE do not trigger disconnect callbacks if the stated did not change
- Fix route calculation callback loop
- Move route state updates into protected scope by mutex
- Do not calculate routes in case of peer.Open() and peer.Close()
2024-11-11 10:53:57 +01:00
Viktor Liu
08b6e9d647 [management] Fix api error message typo peers_group (#2862) 2024-11-08 23:28:02 +01:00
Pascal Fischer
67ce14eaea [management] Add peer lock to grpc server (#2859)
* add peer lock to grpc server

* remove sleep and put db update first

* don't export lock method
2024-11-08 18:47:22 +01:00
Pascal Fischer
669904cd06 [management] Remove context from database calls (#2863) 2024-11-08 15:49:00 +01:00
Zoltan Papp
4be826450b [client] Use offload in WireGuard bind receiver (#2815)
Improve the performance on Linux and Android in case of P2P connections
2024-11-07 17:28:38 +01:00
Maycon Santos
738387f2de Add benchmark tests to get account with claims (#2761)
* Add benchmark tests to get account with claims

* add users to account objects

* remove hardcoded env
2024-11-07 17:23:35 +01:00
Pascal Fischer
baf0678ceb [management] Fix potential panic on inactivity expiration log message (#2854) 2024-11-07 16:33:57 +01:00
Pascal Fischer
7fef8f6758 [management] Enforce max conn of 1 for sqlite setups (#2855) 2024-11-07 16:32:35 +01:00
Viktor Liu
6829a64a2d [client] Exclude split default route ip addresses from anonymization (#2853) 2024-11-07 16:29:32 +01:00
Zoltan Papp
cbf500024f [relay-server] Use X-Real-IP in case of reverse proxy (#2848)
* Use X-Real-IP in case of reverse proxy

* Use sprintf
2024-11-07 16:14:53 +01:00
Viktor Liu
509e184e10 [client] Use the prerouting chain to mark for masquerading to support older systems (#2808) 2024-11-07 12:37:04 +01:00
Pascal Fischer
3e88b7c56e [management] Fix network map update on peer validation (#2849) 2024-11-07 09:50:13 +01:00
Maycon Santos
b952d8693d Fix cached device flow oauth (#2833)
This change removes the cached device flow oauth info when a down command is called

Removing the need for the agent to be restarted
2024-11-05 14:51:17 +01:00
Maycon Santos
5b46cc8e9c Avoid failing all other matrix tests if one fails (#2839) 2024-11-05 13:28:42 +01:00
Pascal Fischer
a9d06b883f add all group to add peer affected peers network map check (#2830) 2024-11-01 22:09:08 +01:00
88 changed files with 2957 additions and 1418 deletions

View File

@@ -13,6 +13,7 @@ concurrency:
jobs: jobs:
test: test:
strategy: strategy:
fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.16" SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -20,6 +20,10 @@
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p> </p>
</div> </div>

View File

@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
} }
if slices.Contains(wellKnown, addr.String()) { if slices.Contains(wellKnown, addr.String()) {

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"sync"
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -17,6 +18,7 @@ type program struct {
cancel context.CancelFunc cancel context.CancelFunc
serv *grpc.Server serv *grpc.Server
serverInstance *server.Server serverInstance *server.Server
serverInstanceMu sync.Mutex
} }
func newProgram(ctx context.Context, cancel context.CancelFunc) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
} }
proto.RegisterDaemonServiceServer(p.serv, serverInstance) proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1]) log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil { if err := p.serv.Serve(listen); err != nil {
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
} }
func (p *program) Stop(srv service.Service) error { func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil { if p.serverInstance != nil {
in := new(proto.DownRequest) in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in) _, err := p.serverInstance.Down(p.ctx, in)
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err) log.Errorf("failed to stop daemon: %v", err)
} }
} }
p.serverInstanceMu.Unlock()
p.cancel() p.cancel()

View File

@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{ m.optionalEntries["FORWARD"] = []entry{
{ {
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2, position: 2,
}, },
} }
m.optionalEntries["PREROUTING"] = []entry{ m.optionalEntries["PREROUTING"] = []entry{
{ {
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1, position: 1,
}, },
} }

View File

@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}()
return nil return nil
} }

View File

@@ -18,22 +18,24 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) nbnet "github.com/netbirdio/netbird/util/net"
const (
ipv4Nat = "netbird-rt-nat"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set" matchSet = "--match-set"
) )
@@ -323,24 +325,25 @@ func (r *router) Reset() error {
} }
func (r *router) cleanUpDefaultForwardRules() error { func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules() if err := r.cleanJumpRules(); err != nil {
if err != nil { return fmt.Errorf("clean jump rules: %w", err)
return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chainInfo := range []struct {
table := r.getTableForChain(chain) chain string
table string
ok, err := r.iptablesClient.ChainExists(table, chain) }{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err) return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
return err
} else if ok { } else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain) if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
if err != nil { return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
} }
} }
} }
@@ -349,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
} }
func (r *router) createContainers() error { func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chainInfo := range []struct {
if err := r.createAndSetupChain(chain); err != nil { chain string
return fmt.Errorf("create chain %s: %w", chain, err) table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
} }
@@ -359,6 +369,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("insert established rule: %w", err) return fmt.Errorf("insert established rule: %w", err)
} }
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil { if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err) return fmt.Errorf("add jump rules: %w", err)
} }
@@ -366,6 +380,32 @@ func (r *router) createContainers() error {
return nil return nil
} }
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error { func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain) table := r.getTableForChain(chain)
@@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
} }
func (r *router) getTableForChain(chain string) string { func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT { switch chain {
case chainRTNAT:
return tableNat return tableNat
} case chainRTPRE:
return tableMangle
default:
return tableFilter return tableFilter
}
} }
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
@@ -398,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
} }
func (r *router) addJumpRules() error { func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT} // Jump to NAT chain
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) natRule := []string{"-j", chainRTNAT}
if err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return err return fmt.Errorf("add nat jump rule: %v", err)
} }
r.rules[ipv4Nat] = rule r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil return nil
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat] for _, ruleKey := range []string{jumpNat, jumpPre} {
if found { if rule, exists := r.rules[ruleKey]; exists {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) table := tableNat
if err != nil { chain := chainPOSTROUTING
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) if ruleKey == jumpPre {
} table = tableMangle
chain = chainPREROUTING
} }
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
}
}
return nil return nil
} }
@@ -424,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} }
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) markValue := nbnet.PreroutingFwmarkMasquerade
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { if pair.Inverse {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
r.rules[ruleKey] = rule r.rules[ruleKey] = rule
return nil return nil
} }
@@ -444,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else { } else {
log.Debugf("nat rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
return nil return nil
@@ -482,16 +555,6 @@ func (r *router) updateState() {
} }
} }
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string var rule []string

View File

@@ -3,17 +3,18 @@
package iptables package iptables
import ( import (
"fmt"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {
@@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.Len(t, manager.rules, 2, "should have created rules map") // Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true, Masquerade: true,
} }
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.AddNatRule(pair)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "adding NAT rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_AddNatRule(t *testing.T) { func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
err := manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}() }()
err = manager.AddNatRule(testCase.InputPair) err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) markingRule := []string{
"-i", ifaceMock.Name(),
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) "-m", "conntrack",
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) "--ctstate", "NEW",
if testCase.InputPair.Masquerade { "-s", testCase.InputPair.Source.String(),
require.True(t, exists, "nat rule should be created") "-d", testCase.InputPair.Destination.String(),
foundNatRule, foundNat := manager.rules[natRuleKey] "-j", "MARK", "--set-mark",
require.True(t, foundNat, "nat rule should exist in the map") fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
} }
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created") require.True(t, exists, "marking rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey] foundRule, found := manager.rules[natRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map") require.True(t, found, "marking rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") require.Equal(t, markingRule, foundRule, "stored marking rule should match")
} else { } else {
require.False(t, exists, "nat rule should not be created") require.False(t, exists, "marking rule should not be created")
_, foundNat := manager.rules[inNatRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map") require.False(t, found, "marking rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
} }
}) })
} }
} }
func TestIptablesManager_RemoveNatRule(t *testing.T) { func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.NoError(t, err, "shouldn't return error") err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) markingRule := []string{
require.False(t, exists, "nat rule should not exist") "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map") require.False(t, found, "marking rule should not exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) // Check inverse rule removal
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) inversePair := firewall.GetInversePair(testCase.InputPair)
require.False(t, exists, "income nat rule should not exist") inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
_, found = manager.rules[inNatRuleKey] exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.False(t, found, "income nat rule should exist in the manager map") require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
}) })
} }
} }

View File

@@ -17,6 +17,7 @@ import (
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t" NatFormat = "netbird-nat-%s-%t"
) )

View File

@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
}, },
&expr.Immediate{ &expr.Immediate{
Register: 1, Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyMARK, Key: expr.MetaKeyMARK,
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,

View File

@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early // persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}()
return nil return nil
} }
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c chain = c
break break
} }
@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Verdict{}, &expr.Verdict{
Kind: expr.VerdictAccept,
},
}, },
UserData: []byte(allowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }

View File

@@ -21,6 +21,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -124,7 +125,6 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1 prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat, Name: chainNameRoutingNat,
Table: r.workTable, Table: r.workTable,
@@ -133,6 +133,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil { if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
@@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME op := expr.CmpOpEq
notDir := expr.MetaKeyOIFNAME
if pair.Inverse { if pair.Inverse {
dir = expr.MetaKeyOIFNAME op = expr.CmpOpNeq
notDir = expr.MetaKeyIIFNAME
} }
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Meta{ // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
Key: dir, // Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1, Register: 1,
}, },
&expr.Cmp{ &expr.Bitwise{
Op: expr.CmpOpEq, SourceRegister: 1,
Register: 1, DestRegister: 1,
Data: intf, Len: 4,
}, Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
}, },
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpNeq, Op: expr.CmpOpNeq,
Register: 1, Register: 1,
Data: lo, Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
}, },
} }
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs, exprs = append(exprs,
&expr.Counter{}, &expr.Masq{}, &expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
) )
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists { if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingNat], Chain: r.chains[chainNamePrerouting],
Exprs: exprs, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil return nil
} }
@@ -723,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// RemoveNatRule removes a nftables rule pair from nat chains // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -749,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil return nil
} }
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error { func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule) err := r.conn.DelRule(rule)
if err != nil { if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else { } else {
log.Debugf("nftables: nat rule %s not found", ruleKey) log.Debugf("nftables: prerouting rule %s not found", ruleKey)
} }
return nil return nil

View File

@@ -10,6 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -32,100 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock) // need fw manager to init both acl mgr and router for all chains to be present
require.NoError(t, err, "failed to create router") manager, err := Create(ifaceMock)
require.NoError(t, manager.init(table)) t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { rtr := manager.router
require.NoError(t, manager.Reset(), "failed to reset rules") err = rtr.AddNatRule(testCase.InputPair)
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted") require.NoError(t, err, "pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) { t.Cleanup(func() {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
}(manager, testCase.InputPair) })
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) // Build expected expressions for connection tracking
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) conntrackExprs := []expr.Any{
testingExpression := append(sourceExp, destExp...) //nolint:gocritic &expr.Ct{
testingExpression = append(testingExpression, Key: expr.CtKeySTATE,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: ifname(ifaceMock.Name()), Data: ifname(ifaceMock.Name()),
}, },
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, }
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) // Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") // Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1 found = 1
} }
} }
} }
require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
} }
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
}) })
} }
} }
@@ -135,68 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create router") t.Cleanup(func() {
require.NoError(t, manager.init(table)) require.NoError(t, manager.Reset(nil))
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
}) })
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) rtr := manager.router
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic // First add the NAT rule using the router's method
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ // Verify the rule was added
Table: manager.workTable, natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
Chain: manager.chains[chainNameRoutingNat], found := false
Exprs: natExp, rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
UserData: []byte(inNatRuleKey), require.NoError(t, err, "should list rules")
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") found = true
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") break
} }
} }
require.True(t, found, "NAT rule should exist before removal")
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
} }
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
}) })
} }
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
) )
@@ -24,8 +25,8 @@ type receiverCreator struct {
iceBind *ICEBind iceBind *ICEBind
} }
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn) return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
} }
// ICEBind is a bind implementation with two main features: // ICEBind is a bind implementation with two main features:
@@ -154,7 +155,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil return nil
} }
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
@@ -166,17 +167,31 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message) msgs := getMessages(msgsPool)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs { for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i] (*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
} }
defer putMessages(msgs, msgsPool)
var numMsgs int var numMsgs int
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0) numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }
}
} else { } else {
msg := &(*msgs)[0] msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
@@ -191,11 +206,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
// todo: handle err // todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok { if ok {
sizes[i] = 0 continue
} else { }
sizes[i] = msg.N sizes[i] = msg.N
if sizes[i] == 0 {
continue
} }
addrPort := msg.Addr.(*net.UDPAddr).AddrPort() addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
@@ -273,3 +289,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
} }
return newAddr, nil return newAddr, nil
} }
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -2,6 +2,7 @@ package bind
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
p.Bind.RemoveEndpoint(p.wgAddr) p.Bind.RemoveEndpoint(p.wgAddr)
return p.remoteConn.Close() if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
} }
func (p *ProxyBind) proxyToLocal(ctx context.Context) { func (p *ProxyBind) proxyToLocal(ctx context.Context) {

View File

@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
if err := e.remoteConn.Close(); err != nil { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil

View File

@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
p.cancel() p.cancel()
var result *multierror.Error var result *multierror.Error
if err := p.remoteConn.Close(); err != nil { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }

View File

@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path // WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error { func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config) return util.WriteJson(context.Background(), path, config)
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
} }
if updated { if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
} }

View File

@@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx) engineCtx, cancel := context.WithCancel(c.ctx)
defer func() { defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err) _, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.CleanLocalPeerState() c.statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
@@ -207,7 +208,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(nil) c.statusRecorder.MarkSignalDisconnected(nil)
defer func() { defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err) _, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
}() }()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal

View File

@@ -7,7 +7,6 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err) log.Error(err)
} }
go func() {
// persist dns state right away // persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) if err := s.stateManager.PersistState(s.ctx); err != nil {
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err) log.Errorf("Failed to persist dns state: %v", err)
} }
}()
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
// persist dns state right away go func() {
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) if err := s.stateManager.PersistState(s.ctx); err != nil {
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err) l.Errorf("Failed to persist dns state: %v", err)
} }
}()
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()

View File

@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53, Port: 53,
}, },
}, },
Domains: []string{"customdomain.com"}, Domains: []string{"google.com"},
Primary: false, Primary: false,
}, },
}, },
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData { if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err) t.Fatalf("invalid zone record: %v", err)
} }
_, err = resolver.LookupHost(context.Background(), "customdomain.com") _, err = resolver.LookupHost(context.Background(), "google.com")
if err != nil { if err != nil {
t.Errorf("failed to resolve: %s", err) t.Errorf("failed to resolve: %s", err)
} }

View File

@@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"slices" "slices"
"sort"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -38,7 +39,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil { if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err) return fmt.Errorf("failed to stop state manager: %w", err)
} }
if err := e.stateManager.PersistState(ctx); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
@@ -641,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String() oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -1481,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files) return slices.Equal(checks.Files, oChecks.Files)
}) })

View File

@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
} }
} }
func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {

View File

@@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.iceP2PIsActive() { if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
@@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround() wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
@@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
} }
} }
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
}
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }

View File

@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
func (s *State) GetRoutes() map[string]struct{} { func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock() s.Mux.RLock()
defer s.Mux.RUnlock() defer s.Mux.RUnlock()
return s.routes return maps.Clone(s.routes)
} }
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
@@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP peerState.IP = receivedState.IP
} }
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus { if receivedState.ConnStatus != peerState.ConnStatus {
@@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil return nil
} }
ch, found := d.changeNotify[receivedState.PubKey] d.notifyPeerListChanged()
if found && ch != nil { return nil
close(ch) }
d.changeNotify[receivedState.PubKey] = nil
func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
} }
peerState.AddRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.DeleteRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
@@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil return nil
} }
ch, found := d.changeNotify[receivedState.PubKey] d.notifyPeerStateChangeListeners(receivedState.PubKey)
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil return nil
} }
ch, found := d.changeNotify[receivedState.PubKey] d.notifyPeerStateChangeListeners(receivedState.PubKey)
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil return nil
} }
ch, found := d.changeNotify[receivedState.PubKey] d.notifyPeerStateChangeListeners(receivedState.PubKey)
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil return nil
} }
ch, found := d.changeNotify[receivedState.PubKey] d.notifyPeerStateChangeListeners(receivedState.PubKey)
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
@@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
ch, found := d.changeNotify[peer] ch, found := d.changeNotify[peer]
if !found || ch == nil { if found {
return ch
}
ch = make(chan struct{}) ch = make(chan struct{})
d.changeNotify[peer] = ch d.changeNotify[peer] = ch
}
return ch return ch
} }
@@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState) d.notifier.updateServerStates(d.managementState, d.signalState)
} }
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}
close(ch)
delete(d.changeNotify, peerID)
}
func (d *Status) notifyPeerListChanged() { func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers()) d.notifier.peerListChanged(d.numOfPeers())
} }

View File

@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
peerState.IP = ip peerState.IP = ip
err := status.UpdatePeerState(peerState) err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
select { select {

View File

@@ -57,6 +57,9 @@ type WorkerICE struct {
localUfrag string localUfrag string
localPwd string localPwd string
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
@@ -194,8 +197,7 @@ func (w *WorkerICE) Close() {
return return
} }
err := w.agent.Close() if err := w.agent.Close(); err != nil {
if err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
} }
@@ -215,15 +217,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected) w.conn.OnStatusChanged(StatusDisconnected)
}
w.muxAgent.Lock() w.closeAgent(agentCancel)
agentCancel() default:
_ = agent.Close() return
w.agent = nil
w.muxAgent.Unlock()
} }
}) })
if err != nil { if err != nil {
@@ -249,6 +254,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil return agent, nil
} }
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration // wait local endpoint configuration
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10 tempScore = float64(metricDiff) * 10
} }
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route // in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := time.Second latency := 999 * time.Millisecond
if peerStatus.latency != 0 { if peerStatus.latency != 0 {
latency = peerStatus.latency latency = peerStatus.latency
} else { } else {
log.Warnf("peer %s has 0 latency", r.Peer) log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
} }
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds() tempScore += 1 - latency.Seconds()
if !peerStatus.relayed { if !peerStatus.relayed {
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
} }
} }
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
switch { switch {
case chosen == "": case chosen == "":
var peers []string var peers []string
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() { func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes { for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer] _, found := c.routePeersNotifiers[r.Peer]
if !found { if found {
c.routePeersNotifiers[r.Peer] = make(chan struct{}) continue
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
} }
closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
} }
} }
func (c *clientNetwork) removeRouteFromWireguardPeer() error { func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
c.removeStateRoute() if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := c.handler.RemoveAllowedIPs(); err != nil { if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err) return fmt.Errorf("remove allowed IPs: %w", err)
@@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
var merr *multierror.Error var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil { if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)) merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
} }
if err := c.handler.RemoveRoute(); err != nil { if err := c.handler.RemoveRoute(); err != nil {
@@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
} }
} else { } else {
// Otherwise, remove the allowed IPs from the previous peer first // Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil { if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err) return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
} }
} }
@@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err) return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
} }
c.addStateRoute() err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
return nil return nil
} }
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() { go func() {
c.routeUpdate <- update c.routeUpdate <- update

View File

@@ -1,6 +1,7 @@
package routemanager package routemanager
import ( import (
"fmt"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1", currentRoute: "route1",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{ {
name: "current route with bad score should be changed to route with better score", name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
} }
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{ currentRoute := &route.Route{

View File

@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct { type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys // refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O] refCountMap map[Key]Ref[O]
refCountMu sync.Mutex mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal // idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O] add AddFunc[Key, I, O]
remove RemoveFunc[Key, O] remove RemoveFunc[Key, O]
} }
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData( func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O], existingCounter *Counter[Key, I, O],
) { ) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.refCountMap = existingCounter.refCountMap rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap rm.idMap = existingCounter.idMap
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key. // Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false. // If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
return ref, ok return ref, ok
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key. // Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key] ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID. // IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
ref, err := rm.Increment(key, in) ref, err := rm.increment(key, in)
if err != nil { if err != nil {
return ref, fmt.Errorf("with ID: %w", err) return ref, fmt.Errorf("with ID: %w", err)
} }
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key. // Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.decrement(key)
}
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
if !ok { if !ok {
logCallerF("No reference found for key %v", key) logCallerF("No reference found for key %v", key)
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID. // DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for _, key := range rm.idMap[id] { for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil { if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
} }
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key. // Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error { func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for key := range rm.refCountMap { for key := range rm.refCountMap {
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc. // Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() { func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
clear(rm.refCountMap) clear(rm.refCountMap)
clear(rm.idMap) clear(rm.idMap)
@@ -217,6 +217,9 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter. // MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.mu.Lock()
defer rm.mu.Unlock()
return json.Marshal(struct { return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"` RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"` IDMap map[string][]Key `json:"idMap"`

View File

@@ -2,31 +2,28 @@ package systemops
import ( import (
"net/netip" "net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
) )
type ShutdownState struct { type ShutdownState ExclusionCounter
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
return "route_state" return "route_state"
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil) sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter) sysops.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysops.refCounter.Flush()
} }
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

View File

@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
r.updateState(stateManager)
return nexthop, err return nexthop, err
}, },
func(prefix netip.Prefix, nexthop Nexthop) error { r.removeFromRouteTable,
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop)
},
) )
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses) return r.setupHooks(initAddresses, stateManager)
} }
// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) { func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager) if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
} }
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("adding route reference: %v", err) return fmt.Errorf("adding route reference: %v", err)
} }
r.updateState(stateManager)
return nil return nil
} }
afterHook := func(connID nbnet.ConnectionID) error { afterHook := func(connID nbnet.ConnectionID) error {
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("remove route reference: %w", err) return fmt.Errorf("remove route reference: %w", err)
} }
r.updateState(stateManager)
return nil return nil
} }
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes // Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix return isVpn, longestPrefix
} }
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}
return shutdownState
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup // isLegacy determines whether to use the legacy routing setup
func isLegacy() bool { func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
} }
// setIsLegacy sets the legacy routing setup // setIsLegacy sets the legacy routing setup

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
) )
// State interface defines the methods that all state types must implement // State interface defines the methods that all state types must implement
@@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if m.cancel != nil { if m.cancel == nil {
return nil
}
m.cancel() m.cancel()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-m.done: case <-m.done:
return nil
}
} }
return nil return nil
@@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil return nil
} }
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
done := make(chan error, 1) done := make(chan error, 1)
start := time.Now()
go func() { go func() {
data, err := json.MarshalIndent(m.states, "", " ") done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
}() }()
select { select {
@@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
} }
} }
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty) clear(m.dirty)
@@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error
func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()
return bs, err
}

View File

@@ -4,32 +4,20 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
// GetDefaultStatePath returns the path to the state file based on the operating system // GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. // It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string { func GetDefaultStatePath() string {
var path string
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux": case "darwin", "linux":
path = "/var/lib/netbird/state.json" return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly": case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json" return "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
} }
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return "" return ""
}
return path
} }

View File

@@ -1,13 +1,41 @@
package main package main
import ( import (
"net/http"
_ "net/http/pprof"
"os" "os"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/cmd"
) )
func main() { func main() {
// only if env is not set
if os.Getenv("NB_LOG_LEVEL") == "" {
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
log.Errorf("Failed setting log-level: %v", err)
}
}
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
log.Errorf("Failed setting log-size: %v", err)
}
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
log.Errorf("Failed setting panic log path: %v", err)
}
go startPprofServer()
if err := cmd.Execute(); err != nil { if err := cmd.Execute(); err != nil {
os.Exit(1) os.Exit(1)
} }
} }
func startPprofServer() {
pprofAddr := "localhost:6969"
log.Infof("Starting pprof debugging server on %s", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Infof("pprof server failed: %v", err)
}
}

View File

@@ -120,6 +120,35 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx) ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel s.actCancel = cancel
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
log.Infof("Error getting status: %v", err)
} else if statusResp.FullStatus != nil {
log.Infof("Status --------")
for _, peer := range statusResp.FullStatus.Peers {
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
peer.Fqdn,
peer.IP,
peer.PubKey,
peer.ConnStatus,
peer.Relayed,
peer.RelayAddress,
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
)
}
}
}
}
}()
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin // if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
// on failure we return error to retry // on failure we return error to retry
config, err := internal.UpdateConfig(s.latestConfigInput) config, err := internal.UpdateConfig(s.latestConfigInput)
@@ -626,6 +655,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil { if s.actCancel == nil {
return nil, fmt.Errorf("service is not up") return nil, fmt.Errorf("service is not up")
} }

4
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

8
go.sum
View File

@@ -521,14 +521,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=

View File

@@ -110,7 +110,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@@ -966,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
userPeers := make(map[string]struct{}) userPeers := make(map[string]struct{})
for pid, peer := range a.Peers { for pid, peer := range a.Peers {
if peer.UserID == userID { if peer.UserID == userID {
@@ -980,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
continue continue
} }
oldPeers := group.Peers
groupPeers := make(map[string]struct{}) groupPeers := make(map[string]struct{})
for _, pid := range group.Peers { for _, pid := range group.Peers {
groupPeers[pid] = struct{}{} groupPeers[pid] = struct{}{}
@@ -993,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
for pid := range groupPeers { for pid := range groupPeers {
group.Peers = append(group.Peers, pid) group.Peers = append(group.Peers, pid)
} }
groupUpdates[gid] = difference(group.Peers, oldPeers)
} }
return groupUpdates
} }
// UserGroupsRemoveFromPeers removes groups from all peers of user // UserGroupsRemoveFromPeers removes groups from all peers of user
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
for _, gid := range groups { for _, gid := range groups {
group, ok := a.Groups[gid] group, ok := a.Groups[gid]
if !ok || group.Name == "All" { if !ok || group.Name == "All" {
continue continue
} }
oldPeers := group.Peers
update := make([]string, 0, len(group.Peers)) update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers { for _, pid := range group.Peers {
peer, ok := a.Peers[pid] peer, ok := a.Peers[pid]
@@ -1014,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
} }
} }
group.Peers = update group.Peers = update
groupUpdates[gid] = difference(oldPeers, group.Peers)
} }
return groupUpdates
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
@@ -1176,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err return nil, err
} }
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
updatedAccount := account.UpdateSettings(newSettings) updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
@@ -1186,7 +1206,29 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil return updatedAccount, nil
} }
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
if newSettings.GroupsPropagationEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
// Todo: retroactively add user groups to all peers
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
}
}
return nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled { if !newSettings.PeerInactivityExpirationEnabled {
@@ -1197,10 +1239,6 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
} }
return nil return nil
@@ -1249,7 +1287,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id) log.Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextInactivePeerExpiration() return account.GetNextInactivePeerExpiration()
} }
@@ -1435,7 +1473,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager // addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -2029,7 +2067,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, accountID) groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2059,7 +2097,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID) groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2083,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@@ -2101,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@@ -2114,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@@ -2127,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil { if err != nil {
return fmt.Errorf("error getting account: %w", err) return err
} }
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
} }
@@ -2290,12 +2333,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, status.NewGetAccountError(err)
} }
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
@@ -2314,12 +2357,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return status.NewGetAccountError(err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
} }
return nil return nil
@@ -2335,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock() defer unlock()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
@@ -2398,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID) am.updateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
} }
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@@ -29,14 +29,18 @@ import (
) )
type MocIntegratedValidator struct { type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
} }
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, nil if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{}) validatedPeers := make(map[string]struct{})
@@ -978,6 +982,110 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
} }
} }
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := jwtclaims.AuthorizationClaims{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
}
publicClaims := jwtclaims.AuthorizationClaims{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: PublicCategory,
}
am, err := createManager(b)
if err != nil {
b.Fatal(err)
return
}
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
users := genUsers("priv", 100)
acc, err := am.Store.GetAccount(context.Background(), id)
if err != nil {
b.Fatal(err)
}
acc.Users = users
err = am.Store.SaveAccount(context.Background(), acc)
if err != nil {
b.Fatal(err)
}
userP := genUsers("pub", 100)
pacc, err := am.Store.GetAccount(context.Background(), pid)
if err != nil {
b.Fatal(err)
}
pacc.Users = userP
err = am.Store.SaveAccount(context.Background(), pacc)
if err != nil {
b.Fatal(err)
}
b.Run("public without account ID", func(b *testing.B) {
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private without account ID", func(b *testing.B) {
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id
//b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
}
func genUsers(p string, n int) map[string]*User {
users := map[string]*User{}
now := time.Now()
for i := 0; i < n; i++ {
users[fmt.Sprintf("%s-%d", p, i)] = &User{
Id: fmt.Sprintf("%s-%d", p, i),
Role: UserRoleAdmin,
LastLogin: now,
CreatedAt: now,
Issued: "api",
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
}
}
return users
}
func TestAccountManager_AddPeer(t *testing.T) { func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
@@ -1305,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Peers: []string{peer1.ID, peer2.ID, peer3.ID},
} })
require.NoError(t, err, "failed to save group")
policy := Policy{ policy := Policy{
Enabled: true, Enabled: true,
@@ -1352,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return return
} }
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -2606,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2626,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2665,7 +2775,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims) err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")

View File

@@ -148,6 +148,9 @@ const (
AccountPeerInactivityExpirationDurationUpdated Activity = 67 AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68 SetupKeyDeleted Activity = 68
UserGroupPropagationEnabled Activity = 69
UserGroupPropagationDisabled Activity = 70
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error { func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now() start := time.Now()
err := util.WriteJson(file, s) err := util.WriteJson(context.Background(), file, s)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -6,11 +6,12 @@ import (
"fmt" "fmt"
"slices" "slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
} }
return nil return nil
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account. // SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock. // Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func() var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
for _, newGroup := range newGroups { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { groupIDs := make([]string, 0, len(groups))
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) for _, newGroup := range groups {
} if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err return err
} }
}
// Avoid duplicate groups only for the API issued groups. newGroup.AccountID = accountID
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. groupsToSave = append(groupsToSave, newGroup)
if existingGroup != nil { groupIDs = append(groupIDs, newGroup.ID)
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String() events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
} }
newGroupIDs := make([]string, 0, len(newGroups)) updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
for _, newGroup := range newGroups { if err != nil {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
am.updateAccountPeers(ctx, account) return err
}
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
})
if err != nil {
return err
} }
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { modifiedPeers := slices.Concat(addedPeers, removedPeers)
peer := account.Peers[p] peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, ok := peers[peerID]
if peer == nil { if !ok {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
}) })
} }
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil
} }
// DeleteGroups deletes groups from an account. // DeleteGroups deletes groups from an account.
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
// //
// If an error occurs while deleting a group, the function skips it and continues deleting other groups. // If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end. // Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var allErrors error var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, ok := account.Groups[groupID] group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if !ok { if err != nil {
allErrors = errors.Join(allErrors, err)
continue continue
} }
if err := validateDeleteGroup(account, group, userId); err != nil { if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) allErrors = errors.Join(allErrors, err)
continue continue
} }
delete(account.Groups, groupID) groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
account.Network.IncSerial() if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
for _, g := range deletedGroups { return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) })
if err != nil {
return err
}
for _, group := range deletedGroups {
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updated := group.AddPeer(peerID); !updated {
if !ok { return nil
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
add := true updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
for _, itemID := range group.Peers { if err != nil {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
am.updateAccountPeers(ctx, account) return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updated := group.RemovePeer(peerID); !updated {
if !ok { return nil
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
account.Network.IncSerial() updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
for i, itemID := range group.Peers { if err != nil {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
am.updateAccountPeers(ctx, account) })
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { // validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return err
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return &GroupLinkError{"integrated validator", group.Name} return err
} }
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
} }
return nil return nil
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r return true, r
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies { for _, policy := range policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
} }
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups { for _, dns := range nameServerGroups {
for _, g := range dns.Groups { for _, g := range dns.Groups {
if g == groupID { if g == groupID {
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
} }
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys { for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) { if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey return true, setupKey
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
} }
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return true, user return true, user
@@ -477,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil return false, nil
} }
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers. // anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool { func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
@@ -486,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
} }
return false return false
} }
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}

View File

@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
func (g *Group) HasPeers() bool { func (g *Group) HasPeers() bool {
return len(g.Peers) > 0 return len(g.Peers) > 0
} }
// IsGroupAll checks if the group is a default "All" group.
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@@ -0,0 +1,90 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{ {
name: "delete non-existent group", name: "delete non-existent group",
groupIDs: []string{"non-existent-group"}, groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"}, expectedReasons: []string{"group: non-existent-group not found"},
}, },
{ {
name: "delete multiple groups with mixed results", name: "delete multiple groups with mixed results",

View File

@@ -6,6 +6,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"strings" "strings"
"sync"
"time" "time"
pb "github.com/golang/protobuf/proto" // nolint pb "github.com/golang/protobuf/proto" // nolint
@@ -38,6 +39,7 @@ type GRPCServer struct {
jwtClaimsExtractor *jwtclaims.ClaimsExtractor jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil { if err != nil {
// nolint:staticcheck // nolint:staticcheck
@@ -171,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err) return mapError(ctx, err)
} }
@@ -190,11 +200,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
} }
unlock()
unlock = nil
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
} }
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for { for {
select { select {
// condition when there are some updates // condition when there are some updates
@@ -245,10 +259,18 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
} }
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID) s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
} }
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -274,6 +296,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return claims.UserId, nil return claims.UserId, nil
} }
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
start = time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
// maps internal internalStatus.Error to gRPC status.Error // maps internal internalStatus.Error to gRPC status.Error
func mapError(ctx context.Context, err error) error { func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok { if e, ok := internalStatus.FromError(err); ok {

View File

@@ -439,17 +439,13 @@ components:
example: 5 example: 5
required: required:
- accessible_peers_count - accessible_peers_count
SetupKey: SetupKeyBase:
type: object type: object
properties: properties:
id: id:
description: Setup Key ID description: Setup Key ID
type: string type: string
example: 2531583362 example: 2531583362
key:
description: Setup Key value
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
name: name:
description: Setup key name identifier description: Setup key name identifier
type: string type: string
@@ -518,22 +514,31 @@ components:
- updated_at - updated_at
- usage_limit - usage_limit
- ephemeral - ephemeral
SetupKeyClear:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as plain text
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
required:
- key
SetupKey:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as secret
type: string
example: A6160****
required:
- key
SetupKeyRequest: SetupKeyRequest:
type: object type: object
properties: properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds, 0 will mean the key never expires
type: integer
minimum: 0
example: 86400
revoked: revoked:
description: Setup key revocation status description: Setup key revocation status
type: boolean type: boolean
@@ -544,21 +549,9 @@ components:
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv7m0" example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required: required:
- name
- type
- expires_in
- revoked - revoked
- auto_groups - auto_groups
- usage_limit
CreateSetupKeyRequest: CreateSetupKeyRequest:
type: object type: object
properties: properties:
@@ -1943,7 +1936,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/SetupKey' $ref: '#/components/schemas/SetupKeyClear'
'400': '400':
"$ref": "#/components/responses/bad_request" "$ref": "#/components/responses/bad_request"
'401': '401':

View File

@@ -1062,7 +1062,94 @@ type SetupKey struct {
// Id Setup Key ID // Id Setup Key ID
Id string `json:"id"` Id string `json:"id"`
// Key Setup Key value // Key Setup Key as secret
Key string `json:"key"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyBase defines model for SetupKeyBase.
type SetupKeyBase struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyClear defines model for SetupKeyClear.
type SetupKeyClear struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key as plain text
Key string `json:"key"` Key string `json:"key"`
// LastUsed Setup key last usage date // LastUsed Setup key last usage date
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key // AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Revoked Setup key revocation status // Revoked Setup key revocation status
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
} }
// User defines model for User. // User defines model for User.

View File

@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers)) peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
for _, peer := range account.Peers { if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupsMap := map[string]*nbgroup.Group{}
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range groups {
groupsMap[group.ID] = group
}
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
} }
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
_, ok := groupsChecked[group.ID] _, ok := groupsChecked[group.ID]

View File

@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
} }
if req.Peer == nil && req.PeerGroups == nil { if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided") return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
} }
if req.Peer != nil && req.PeerGroups != nil { if req.Peer != nil && req.PeerGroups != nil {

View File

@@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
return return
} }
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil { if req.AutoGroups == nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return return
@@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey := &server.SetupKey{} newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)

View File

@@ -52,26 +52,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a) return am.Store.SaveAccount(ctx, a)
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groups) == 0 { if len(groupIDs) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(ctx, accountId)
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
}
return nil
})
if err != nil { if err != nil {
return false, err return false, err
} }
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
}
}
if !found {
return false, nil
}
}
return true, nil return true, nil
} }

View File

@@ -11,7 +11,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies // IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface { type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)

View File

@@ -453,8 +453,8 @@ func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtr
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, nil return update, false, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {

View File

@@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {

View File

@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
} }
if anyGroupHasPeers(account, newNSGroup.Groups) { if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
} }
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
} }
if anyGroupHasPeers(account, nsGroup.Groups) { if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

View File

@@ -110,14 +110,16 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to find peer by pub key: %w", err)
} }
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to update peer status and location: %w", err)
} }
log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected)
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, account)
@@ -131,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired { if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from // we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again. // the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
return nil return nil
@@ -166,9 +168,11 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
account.UpdatePeer(peer) account.UpdatePeer(peer)
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, err return false, fmt.Errorf("failed to save peer status: %w", err)
} }
return oldStatus.LoginExpired, nil return oldStatus.LoginExpired, nil
@@ -189,7 +193,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
} }
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) var requiresPeerUpdates bool
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -265,8 +270,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err return nil, err
} }
if peerLabelUpdated { if peerLabelUpdated || requiresPeerUpdates {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return peer, nil return peer, nil
@@ -330,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
updateAccountPeers := isPeerInActiveGroup(account, peerID) updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
if err != nil {
return err
}
err = am.deletePeers(ctx, account, []string{peerID}, userID) err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil { if err != nil {
@@ -343,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
} }
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -550,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err) return fmt.Errorf("failed to add peer to account: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, accountID) err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -586,11 +594,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err) return nil, nil, nil, status.NewGetAccountError(err)
} }
if areGroupChangesAffectPeers(account, groupsToAdd) { allGroup, err := account.GetGroupAll()
am.updateAccountPeers(ctx, account) if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
if err != nil {
return nil, nil, nil, err
}
if newGroupsAffectsPeers {
am.updateAccountPeers(ctx, accountID)
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -633,7 +652,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if peer.UserID != "" { if peer.UserID != "" {
user, err := account.FindUser(peer.UserID) user, err := account.FindUser(peer.UserID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
} }
err = checkIfPeerOwnerIsBlocked(peer, user) err = checkIfPeerOwnerIsBlocked(peer, user)
@@ -648,19 +667,22 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
updated := peer.UpdateMetaIfNew(sync.Meta) updated := peer.UpdateMetaIfNew(sync.Meta)
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
account.Peers[peer.ID] = peer
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer) err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
} }
if sync.UpdateAccountPeers { if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
} }
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
} }
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
@@ -673,12 +695,12 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
} }
if isStatusChanged { if isStatusChanged {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
validPeersMap, err := am.GetValidatedPeers(account) validPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
} }
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
@@ -758,7 +780,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -780,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updated := peer.UpdateMetaIfNew(login.Meta) updated := peer.UpdateMetaIfNew(login.Meta)
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
shouldStorePeer = true shouldStorePeer = true
} }
@@ -804,7 +827,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
if updateRemotePeers || isStatusChanged { if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
@@ -967,7 +990,13 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
if am.metrics != nil { if am.metrics != nil {
@@ -1021,12 +1050,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used // IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration. // in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool { func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0) peerGroupIDs := make([]string, 0)
for _, group := range account.Groups { for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) { if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID) peerGroupIDs = append(peerGroupIDs, group.ID)
} }
} }
return areGroupChangesAffectPeers(account, peerGroupIDs) return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
} }

View File

@@ -22,6 +22,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -876,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now() start := time.Now()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account) manager.updateAccountPeers(ctx, account.Id)
} }
duration := time.Since(start) duration := time.Since(start)
@@ -1398,6 +1399,50 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
} }
}) })
t.Run("validator requires update", func(t *testing.T) {
requireUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, true, nil
}
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
t.Run("validator requires no update", func(t *testing.T) {
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *nbAccount.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Adding peer to group linked with policy should update account peers and send peer update // Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) { t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{

View File

@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) { if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
} }
if isRouteChangeAffectPeers(account, &newRoute) { if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
} }
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) { if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"fmt"
"hash/fnv" "hash/fnv"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -229,32 +229,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { if user.AccountID != accountID {
return nil, err return nil, status.NewUserNotPartOfAccountError()
} }
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) if user.IsRegularUser() {
account.SetupKeys[setupKey.Key] = setupKey return nil, status.NewAdminPermissionError()
err = am.Store.SaveAccount(ctx, account) }
var setupKey *SetupKey
var plainKey string
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err
}
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
setupKey.AccountID = accountID
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
})
if err != nil { if err != nil {
return nil, status.Errorf(status.Internal, "failed adding account key") return nil, err
} }
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
for _, storeEvent := range eventsToStore {
for _, g := range setupKey.AutoGroups { storeEvent()
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
} }
// for the creation return the plain key to the caller // for the creation return the plain key to the caller
@@ -266,45 +277,61 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one. // SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
// Due to the unique nature of a SetupKey certain properties must not be overwritten // Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil { if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
} }
account, err := am.Store.GetAccount(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var oldKey *SetupKey var oldKey *SetupKey
for _, key := range account.SetupKeys { var newKey *SetupKey
if key.Id == keyToSave.Id { var eventsToStore []func()
oldKey = key.Copy()
break err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
} if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
} return err
if oldKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
} }
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
return nil, err if err != nil {
return err
} }
// only auto groups, revoked status, and name can be updated for now if oldKey.Revoked && !keyToSave.Revoked {
newKey := oldKey.Copy() return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
newKey.Name = keyToSave.Name }
// only auto groups, revoked status (from false to true) can be updated
newKey = oldKey.Copy()
newKey.AutoGroups = keyToSave.AutoGroups newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC() newKey.UpdatedAt = time.Now().UTC()
account.SetupKeys[newKey.Key] = newKey addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
if err = am.Store.SaveAccount(ctx, account); err != nil { events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
})
if err != nil {
return nil, err return nil, err
} }
@@ -312,31 +339,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta())
} }
defer func() { for _, storeEvent := range eventsToStore {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) storeEvent()
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
} }
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
}
}()
return newKey, nil return newKey, nil
} }
@@ -347,16 +353,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError() return nil, status.NewUserNotPartOfAccountError()
} }
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if user.IsRegularUser() {
if err != nil { return nil, status.NewAdminPermissionError()
return nil, err
} }
return setupKeys, nil return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -366,11 +371,15 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError() return nil, status.NewUserNotPartOfAccountError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -387,21 +396,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user: %w", err) return err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError() return status.NewUserNotPartOfAccountError()
} }
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var deletedSetupKey *SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get setup key: %w", err) return err
} }
err = am.Store.DeleteSetupKey(ctx, accountID, keyID) return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
})
if err != nil { if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err) return err
} }
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
@@ -409,15 +426,62 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return nil return nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
for _, group := range autoGroups { groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
g, ok := account.Groups[group] if err != nil {
return err
}
for _, groupID := range autoGroupIDs {
group, ok := groups[groupID]
if !ok { if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group) return status.Errorf(status.NotFound, "group not found: %s", groupID)
} }
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
} }
} }
return nil return nil
} }
// prepareSetupKeyEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil
}
for _, g := range removedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
})
}
for _, g := range addedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
})
}
return eventsToStore
}

View File

@@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
autoGroups := []string{"group_1", "group_2"} autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key"
revoked := true revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName,
Revoked: revoked, Revoked: revoked,
AutoGroups: autoGroups, AutoGroups: autoGroups,
}, userID) }, userID)
@@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now().UTC(), autoGroups, true) key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
@@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, account.Id, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, keyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
@@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups = append(autoGroups, groupAll.ID) autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName,
Revoked: revoked, Revoked: revoked,
AutoGroups: autoGroups, AutoGroups: autoGroups,
}, userID) }, userID)
@@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ type testCase struct {
ID: "group_2", name string
Name: "group_name_2", keyId string
Peers: []string{}, expectedFailure bool
}
testCase1 := testCase{
name: "Should get existing Setup Key",
keyId: plainKey.Id,
expectedFailure: false,
}
testCase2 := testCase{
name: "Should fail to get non-existent Setup Key",
keyId: "some key",
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
if tCase.expectedFailure {
if err == nil {
t.Fatal("expected to fail")
}
return
}
assert.NotEqual(t, plainKey.Key, key.Key)
}) })
if err != nil {
t.Fatal(err)
} }
} }
@@ -449,3 +465,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
} }
}) })
} }
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
assert.NoError(t, err)
// revoke the key
updateKey := key.Copy()
updateKey.Revoked = true
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
assert.NoError(t, err)
// re-activate revoked key
updateKey.Revoked = false
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
assert.Error(t, err, "should not allow to update revoked key")
}

File diff suppressed because it is too large Load Diff

View File

@@ -14,11 +14,10 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route" route2 "github.com/netbirdio/netbird/route"
@@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group") t.Fatal("failed to save group")
return err return err
} }
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil { if err != nil {
t.Fatal("failed to get group") t.Fatal("failed to get group")
return err return err
@@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID) users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, users, len(account.Users)) require.Len(t, users, len(account.Users))
} }
@@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
} }
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "All", group.Name) require.True(t, group.IsGroupAll())
} }
func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) {
@@ -1274,7 +1273,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
require.NoError(t, err) require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
@@ -1290,6 +1289,278 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id" nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err) require.Error(t, err)
} }
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectedCount int
}{
{
name: "retrieve existing groups by existing IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectedCount: 2,
},
{
name: "empty group IDs list",
groupIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing group IDs",
groupIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing group IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
require.NoError(t, err)
require.Len(t, groups, tt.expectedCount)
})
}
}
func TestSqlStore_SaveGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &nbgroup.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
require.NoError(t, err)
require.Equal(t, savedGroup, group)
}
func TestSqlStore_SaveGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups := []*nbgroup.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
require.NoError(t, err)
}
func TestSqlStore_DeleteGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupID string
expectError bool
}{
{
name: "delete existing group",
groupID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "delete non-existing group",
groupID: "non-existing-group-id",
expectError: true,
},
{
name: "delete with empty group ID",
groupID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
} else {
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
require.Error(t, err)
require.Nil(t, group)
}
})
}
}
func TestSqlStore_DeleteGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectError bool
}{
{
name: "delete multiple existing groups",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectError: false,
},
{
name: "delete non-existing groups",
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
expectError: false,
},
{
name: "delete with empty group IDs list",
groupIDs: []string{},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
for _, groupID := range tt.groupIDs {
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
require.Error(t, err)
require.Nil(t, group)
}
}
})
}
}
func TestSqlStore_GetPeerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerID string
expectError bool
}{
{
name: "retrieve existing peer",
peerID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "retrieve non-existing peer",
peerID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty peer ID",
peerID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, peer)
} else {
require.NoError(t, err)
require.NotNil(t, peer)
require.Equal(t, tt.peerID, peer.ID)
}
})
}
}
func TestSqlStore_GetPeersByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerIDs []string
expectedCount int
}{
{
name: "retrieve existing peers by existing IDs",
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
expectedCount: 2,
},
{
name: "empty peer IDs list",
peerIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing peer IDs",
peerIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing peer IDs",
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
})
}
}

View File

@@ -3,7 +3,6 @@ package status
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
) )
const ( const (
@@ -103,22 +102,27 @@ func NewPeerLoginExpiredError() error {
} }
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error { func NewSetupKeyNotFoundError(setupKeyID string) error {
return Errorf(NotFound, "setup key not found: %s", err) return Errorf(NotFound, "setup key: %s not found", setupKeyID)
} }
func NewGetAccountFromStoreError(err error) error { func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err) return Errorf(Internal, "issue getting account from store: %s", err)
} }
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
func NewUserNotPartOfAccountError() error {
return Errorf(PermissionDenied, "user is not part of this account")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error { func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store") return Errorf(Internal, "issue getting user from store")
} }
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context // NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
func NewStoreContextCanceledError(duration time.Duration) error { func NewAdminPermissionError() error {
return Errorf(Internal, "store access: context canceled after %v", duration) return Errorf(PermissionDenied, "admin role required to perform this action")
} }
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
@@ -126,7 +130,12 @@ func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID") return Errorf(InvalidArgument, "invalid key ID")
} }
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key // NewGetAccountError creates a new Error with Internal type for an issue getting account
func NewUnauthorizedToViewSetupKeysError() error { func NewGetAccountError(err error) error {
return Errorf(Unauthorized, "only users with admin power can view setup keys") return Errorf(Internal, "error getting account: %s", err)
}
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
} }

View File

@@ -62,7 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@@ -70,11 +70,14 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -89,6 +92,8 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@@ -96,7 +101,9 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
@@ -105,7 +112,7 @@ type Store interface {
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string
@@ -124,7 +131,6 @@ type Store interface {
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
} }
type StoreEngine string type StoreEngine string

View File

@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
updateAccountPeersDurationMs metric.Float64Histogram updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter
} }
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics // NewAccountManagerMetrics creates an instance of AccountManagerMetrics
@@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err return nil, err
} }
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{ return &AccountManagerMetrics{
ctx: ctx, ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount, networkMapObjectCount: networkMapObjectCount,
peerMetaUpdateCount: peerMetaUpdateCount,
}, nil }, nil
} }
@@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count) metrics.networkMapObjectCount.Record(metrics.ctx, count)
} }
// CountPeerMetUpdate counts the number of peer meta updates
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
}

View File

@@ -13,6 +13,7 @@ type StoreMetrics struct {
globalLockAcquisitionDurationMs metric.Int64Histogram globalLockAcquisitionDurationMs metric.Int64Histogram
persistenceDurationMicro metric.Int64Histogram persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
ctx context.Context ctx context.Context
} }
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err return nil, err
} }
transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
if err != nil {
return nil, err
}
return &StoreMetrics{ return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro, globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs, globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro, persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs, persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
ctx: ctx, ctx: ctx,
}, nil }, nil
} }
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds()) metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds()) metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
} }
// CountTransactionDuration counts the duration of a store persistence operation
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
}

View File

@@ -96,9 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok { if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID) delete(p.peerChannels, peerID)
close(channel) close(channel)
}
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
return
}
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
} }
// CloseChannels closes updates channel for each given peer // CloseChannels closes updates channel for each given peer

View File

@@ -9,14 +9,16 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -103,6 +105,11 @@ func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser return u.HasAdminPower() || u.IsServiceUser
} }
// IsRegularUser checks if the user is a regular user.
func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object. // ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups autoGroups := u.AutoGroups
@@ -487,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
return nil return nil
@@ -798,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
expiredPeers = append(expiredPeers, blockedPeers...) expiredPeers = append(expiredPeers, blockedPeers...)
} }
peerGroupsAdded := make(map[string][]string)
peerGroupsRemoved := make(map[string][]string)
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated // need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
} }
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, userUpdateEvents...)
userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
eventsToStore = append(eventsToStore, userGroupsEvents...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil { if err != nil {
@@ -828,7 +840,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
@@ -865,9 +877,53 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
}) })
} }
return eventsToStore
}
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
if newUser.AutoGroups != nil { if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
eventsToStore = append(eventsToStore, removedEvents...)
addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
eventsToStore = append(eventsToStore, addedEvents...)
}
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
var eventsToStore []func()
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
}
}
for groupID, peerIDs := range peerGroupsAdded {
group := account.GetGroup(groupID)
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
})
}
}
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
for _, g := range removedGroups { for _, g := range removedGroups {
group := account.GetGroup(g) group := account.GetGroup(g)
if group != nil { if group != nil {
@@ -880,17 +936,19 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
} }
for _, g := range addedGroups { for groupID, peerIDs := range peerGroupsRemoved {
group := account.GetGroup(g) group := account.GetGroup(groupID)
if group != nil { for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, meta := map[string]any{
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) "group": group.Name, "group_id": group.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
}) })
} }
} }
}
return eventsToStore return eventsToStore
} }
@@ -1100,6 +1158,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
var peerIDs []string var peerIDs []string
for _, peer := range peers { for _, peer := range peers {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
if peer.Status.LoginExpired { if peer.Status.LoginExpired {
continue continue
} }
@@ -1107,8 +1168,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peer.MarkLoginExpired(true) peer.MarkLoginExpired(true)
account.UpdatePeer(peer) account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
return err return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err)
} }
log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID)
am.StoreEvent( am.StoreEvent(
ctx, ctx,
peer.UserID, peer.ID, account.Id, peer.UserID, peer.ID, account.Id,
@@ -1119,7 +1183,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 { if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service // this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
return nil return nil
} }
@@ -1227,7 +1291,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
} }
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
for targetUserID, meta := range deletedUsersMeta { for targetUserID, meta := range deletedUsersMeta {

View File

@@ -3,7 +3,6 @@ package client
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@@ -449,11 +448,11 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
conn, ok := c.conns[id] conn, ok := c.conns[id]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
return 0, io.EOF return 0, net.ErrClosed
} }
if conn.conn != connReference { if conn.conn != connReference {
return 0, io.EOF return 0, net.ErrClosed
} }
// todo: use buffer pool instead of create new transport msg. // todo: use buffer pool instead of create new transport msg.
@@ -508,7 +507,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
container, ok := c.conns[id] container, ok := c.conns[id]
if !ok { if !ok {
return fmt.Errorf("connection already closed") return net.ErrClosed
} }
if container.conn != connReference { if container.conn != connReference {

View File

@@ -1,7 +1,6 @@
package client package client
import ( import (
"io"
"net" "net"
"time" "time"
) )
@@ -40,7 +39,7 @@ func (c *Conn) Write(p []byte) (n int, err error) {
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan msg, ok := <-c.messageChan
if !ok { if !ok {
return 0, io.EOF return 0, net.ErrClosed
} }
n = copy(b, msg.Payload) n = copy(b, msg.Payload)

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@@ -100,7 +99,7 @@ func (c *Conn) isClosed() bool {
func (c *Conn) ioErrHandling(err error) error { func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() { if c.isClosed() {
return io.EOF return net.ErrClosed
} }
var wErr *websocket.CloseError var wErr *websocket.CloseError
@@ -108,7 +107,7 @@ func (c *Conn) ioErrHandling(err error) error {
return err return err
} }
if wErr.Code == websocket.StatusNormalClosure { if wErr.Code == websocket.StatusNormalClosure {
return io.EOF return net.ErrClosed
} }
return err return err
} }

View File

@@ -63,13 +63,14 @@ func (l *Listener) Shutdown(ctx context.Context) error {
} }
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
connRemoteAddr := remoteAddr(r)
wsConn, err := websocket.Accept(w, r, nil) wsConn, err := websocket.Accept(w, r, nil)
if err != nil { if err != nil {
log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err) log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
return return
} }
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil { if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error") err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil { if err != nil {
@@ -90,3 +91,10 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
conn := NewConn(wsConn, lAddr, rAddr) conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn) l.acceptFn(conn)
} }
func remoteAddr(r *http.Request) string {
if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" {
return r.RemoteAddr
}
return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}

View File

@@ -2,7 +2,7 @@ package server
import ( import (
"context" "context"
"io" "errors"
"net" "net"
"sync" "sync"
"time" "time"
@@ -16,6 +16,8 @@ import (
const ( const (
bufferSize = 8820 bufferSize = 8820
errCloseConn = "failed to close connection to peer: %s"
) )
// Peer represents a peer connection // Peer represents a peer connection
@@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
defer func() {
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err)
}
}()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -57,7 +65,7 @@ func (p *Peer) Work() {
for { for {
n, err := p.conn.Read(buf) n, err := p.conn.Read(buf)
if err != nil { if err != nil {
if err != io.EOF { if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err) p.log.Errorf("failed to read message: %s", err)
} }
return return
@@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
case messages.MsgTypeClose: case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully") p.log.Infof("peer exited gracefully")
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf("failed to close connection to peer: %s", err) log.Errorf(errCloseConn, err)
} }
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
@@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
} }
err = p.conn.Close() if err := p.conn.Close(); err != nil {
if err != nil { p.log.Errorf(errCloseConn, err)
p.log.Errorf("failed to close connection to peer: %s", err)
} }
} }
@@ -132,7 +139,7 @@ func (p *Peer) Close() {
defer p.connMu.Unlock() defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf(errCloseConn, err)
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -14,8 +15,21 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return fmt.Errorf("prepare config file dir: %w", err)
}
if err = EnforcePermission(file); err != nil {
return fmt.Errorf("enforce permission: %w", err)
}
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
@@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
return err return err
} }
return writeJson(file, obj, configDir, configFileName) return writeJson(ctx, file, obj, configDir, configFileName)
} }
// WriteJson writes JSON config object to a file creating parent directories if required // WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted // The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error { func WriteJson(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
} }
return writeJson(file, obj, configDir, configFileName) return writeJson(ctx, file, obj, configDir, configFileName)
} }
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil return nil
} }
func writeJson(file string, obj interface{}, configDir string, configFileName string) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
return fmt.Errorf("write json start: %w", ctx.Err())
}
// make it pretty // make it pretty
bs, err := json.MarshalIndent(obj, "", " ") bs, err := json.MarshalIndent(obj, "", " ")
if err != nil { if err != nil {
return err return fmt.Errorf("marshal: %w", err)
}
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
if ctx.Err() != nil {
return fmt.Errorf("write bytes start: %w", ctx.Err())
} }
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil { if err != nil {
return err return fmt.Errorf("create temp: %w", err)
} }
tempFileName := tempFile.Name() tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close() if deadline, ok := ctx.Deadline(); ok {
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
log.Warnf("failed to set deadline: %v", err)
}
}
_, err = tempFile.Write(bs)
if err != nil { if err != nil {
return err _ = tempFile.Close()
return fmt.Errorf("write: %w", err)
}
if err = tempFile.Close(); err != nil {
return fmt.Errorf("close %s: %w", tempFileName, err)
} }
defer func() { defer func() {
@@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
} }
}() }()
err = os.WriteFile(tempFileName, bs, 0600) // Check context again
if err != nil { if ctx.Err() != nil {
return err return fmt.Errorf("after temp file: %w", ctx.Err())
} }
err = os.Rename(tempFileName, file) if err = os.Rename(tempFileName, file); err != nil {
if err != nil { return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
return err
} }
return nil return nil

View File

@@ -1,6 +1,7 @@
package util package util
import ( import (
"context"
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"io" "io"
@@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
err := WriteJson(tmpDir+"/testconfig.json", tt.config) err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err) require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
@@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
src := tmpDir + "/copytest_src" src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst" dst := tmpDir + "/copytest_dst"
err := WriteJson(src, tt.srcContent) err := WriteJson(context.Background(), src, tt.srcContent)
require.NoError(t, err) require.NoError(t, err)
err = CopyFileContents(src, dst) err = CopyFileContents(src, dst)
@@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
_ = os.Remove(cfgFile) _ = os.Remove(cfgFile)
}() }()
err := WriteJson(cfgFile, tt.config) err := WriteJson(context.Background(), cfgFile, tt.config)
require.NoError(t, err) require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{}) read, err := ReadJson(cfgFile, &TestConfig{})

View File

@@ -3,6 +3,9 @@ package grpc
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net" "net"
"os/user" "os/user"
"runtime" "runtime"
@@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
currentUser, err := user.Current() currentUser, err := user.Current()
if err != nil { if err != nil {
log.Fatalf("failed to get current user: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
} }
// the custom dialer requires root permissions which are not required for use cases run as non-root // the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" { if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{} dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr) return dialer.DialContext(ctx, "tcp", addr)
} }
} }
log.Debug("Using nbnet.NewDialer()")
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err) log.Errorf("Failed to dial: %s", err)
return nil, err return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
} }
return conn, nil return conn, nil
}) })

View File

@@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, fmt.Errorf("dial: %w", err) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
} }
// Wrap the connection in Conn to handle Close with hooks // Wrap the connection in Conn to handle Close with hooks

View File

@@ -12,7 +12,10 @@ import (
const ( const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard // NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00 NetbirdFwmark = 0x1BD00
PreroutingFwmark = 0x1BD01
PreroutingFwmarkRedirected = 0x1BD01
PreroutingFwmarkMasquerade = 0x1BD11
PreroutingFwmarkMasqueradeReturn = 0x1BD12
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
) )

View File

@@ -4,9 +4,14 @@ package net
import ( import (
"fmt" "fmt"
"os"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
) )
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
// SetSocketMark sets the SO_MARK option on the given socket connection // SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error { func SetSocketMark(conn syscall.Conn) error {
sysconn, err := conn.SyscallConn() sysconn, err := conn.SyscallConn()
@@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
func SetSocketOpt(fd int) error { func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() { if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return nil
}
// Check for the new environment variable
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil return nil
} }