Compare commits

...

26 Commits

Author SHA1 Message Date
Viktor Liu
3efa7a282a Set log level debug 2024-11-27 13:46:37 +01:00
Viktor Liu
40551099b3 Add debug 2024-11-27 13:41:32 +01:00
Pascal Fischer
9db1932664 [management] Fix getSetupKey call (#2927) 2024-11-22 10:15:51 +01:00
Viktor Liu
1bbabf70b0 [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict

* Fix chain name
2024-11-21 16:53:37 +01:00
Pascal Fischer
aa575d6f44 [management] Add activity events to group propagation flow (#2916) 2024-11-21 15:10:34 +01:00
Pascal Fischer
f66bbcc54c [management] Add metric for peer meta update (#2913) 2024-11-19 18:13:26 +01:00
Pascal Fischer
5dd6a08ea6 link peer meta update back to account object (#2911) 2024-11-19 17:25:49 +01:00
Krzysztof Nazarewski (kdn)
eb5d0569ae [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)
* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
2024-11-19 14:14:58 +01:00
Pascal Fischer
52ea2e84e9 [management] Add transaction metrics and exclude getAccount time from peers update (#2904) 2024-11-19 00:04:50 +01:00
Maycon Santos
78fab877c0 [misc] Update signing pipeline version (#2900) 2024-11-18 15:31:53 +01:00
Maycon Santos
65a94f695f use google domain for tests (#2902) 2024-11-18 12:55:02 +01:00
Kursat Aktas
ec543f89fb Introducing NetBird Guru on Gurubase.io (#2778) 2024-11-16 15:45:31 +01:00
Viktor Liu
a7d5c52203 Fix error state race on mgmt connection error (#2892) 2024-11-15 22:59:49 +01:00
Viktor Liu
582bb58714 Move state updates outside the refcounter (#2897) 2024-11-15 22:55:33 +01:00
Viktor Liu
121dfda915 [client] Fix state manager race conditions (#2890) 2024-11-15 20:05:26 +01:00
İsmail
a1c5287b7c Fix the Inactivity Expiration problem. (#2865) 2024-11-15 18:21:27 +01:00
Bethuel Mmbaga
12f442439a [management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-15 20:09:32 +03:00
Pascal Fischer
d9b691b8a5 [management] Limit the setup-key update operation (#2841) 2024-11-15 17:00:06 +01:00
Pascal Fischer
4aee3c9e33 [client/management] add peer lock to peer meta update and fix isEqual func (#2840) 2024-11-15 16:59:03 +01:00
Pascal Fischer
44e799c687 [management] Fix limited peer view groups (#2894) 2024-11-15 11:16:16 +01:00
Viktor Liu
be78efbd42 [client] Handle panic on nil wg interface (#2891) 2024-11-14 20:15:16 +01:00
Maycon Santos
6886691213 Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
2024-11-13 15:21:33 +01:00
Zoltan Papp
b48afd92fd [relay-server] Always close ws conn when work thread exit (#2879)
Close ws conn when work thread exit
2024-11-13 15:02:51 +01:00
Viktor Liu
39329e12a1 [client] Improve state write timeout and abort work early on timeout (#2882)
* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
2024-11-13 13:46:00 +01:00
Pascal Fischer
20a5afc359 [management] Add more logs to the peer update processes (#2881) 2024-11-12 14:19:22 +01:00
Bethuel Mmbaga
6cb697eed6 [management] Refactor setup key to use store methods (#2861)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context from DB queries

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 19:46:10 +03:00
58 changed files with 2026 additions and 776 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.16" SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -19,6 +19,10 @@
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a> </a>
</p> </p>
</div> </div>

View File

@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil { go func() {
log.Errorf("failed to persist state: %v", err) if err := stateManager.PersistState(context.Background()); err != nil {
} log.Errorf("failed to persist state: %v", err)
}
}()
return nil return nil
} }

View File

@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early // persist early
if err := stateManager.PersistState(context.Background()); err != nil { go func() {
log.Errorf("failed to persist state: %v", err) if err := stateManager.PersistState(context.Background()); err != nil {
} log.Errorf("failed to persist state: %v", err)
}
}()
return nil return nil
} }
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c chain = c
break break
} }
@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Verdict{}, &expr.Verdict{
Kind: expr.VerdictAccept,
},
}, },
UserData: []byte(allowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }

View File

@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path // WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error { func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config) return util.WriteJson(context.Background(), path, config)
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
} }
if updated { if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
} }

View File

@@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx) engineCtx, cancel := context.WithCancel(c.ctx)
defer func() { defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err) _, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.CleanLocalPeerState() c.statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()

View File

@@ -7,7 +7,6 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err) log.Error(err)
} }
// persist dns state right away go func() {
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) // persist dns state right away
defer cancel() if err := s.stateManager.PersistState(s.ctx); err != nil {
if err := s.stateManager.PersistState(ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err)
log.Errorf("Failed to persist dns state: %v", err) }
} }()
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
// persist dns state right away go func() {
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) if err := s.stateManager.PersistState(s.ctx); err != nil {
defer cancel() l.Errorf("Failed to persist dns state: %v", err)
if err := s.stateManager.PersistState(ctx); err != nil { }
l.Errorf("Failed to persist dns state: %v", err) }()
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()

View File

@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53, Port: 53,
}, },
}, },
Domains: []string{"customdomain.com"}, Domains: []string{"google.com"},
Primary: false, Primary: false,
}, },
}, },
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData { if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err) t.Fatalf("invalid zone record: %v", err)
} }
_, err = resolver.LookupHost(context.Background(), "customdomain.com") _, err = resolver.LookupHost(context.Background(), "google.com")
if err != nil { if err != nil {
t.Errorf("failed to resolve: %s", err) t.Errorf("failed to resolve: %s", err)
} }

View File

@@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"slices" "slices"
"sort"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -38,7 +39,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -171,7 +171,7 @@ type Engine struct {
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil { if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err) return fmt.Errorf("failed to stop state manager: %w", err)
} }
if err := e.stateManager.PersistState(ctx); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
@@ -641,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String() oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -1481,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files) return slices.Equal(checks.Files, oChecks.Files)
}) })

View File

@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
} }
} }
func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
package routemanager package routemanager
import ( import (
"fmt"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1", currentRoute: "route1",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{ {
name: "current route with bad score should be changed to route with better score", name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
} }
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{ currentRoute := &route.Route{

View File

@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct { type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys // refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O] refCountMap map[Key]Ref[O]
refCountMu sync.Mutex mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal // idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O] add AddFunc[Key, I, O]
remove RemoveFunc[Key, O] remove RemoveFunc[Key, O]
} }
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData( func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O], existingCounter *Counter[Key, I, O],
) { ) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.refCountMap = existingCounter.refCountMap rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap rm.idMap = existingCounter.idMap
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key. // Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false. // If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
return ref, ok return ref, ok
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key. // Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key] ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID. // IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
ref, err := rm.Increment(key, in) ref, err := rm.increment(key, in)
if err != nil { if err != nil {
return ref, fmt.Errorf("with ID: %w", err) return ref, fmt.Errorf("with ID: %w", err)
} }
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key. // Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.decrement(key)
}
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
if !ok { if !ok {
logCallerF("No reference found for key %v", key) logCallerF("No reference found for key %v", key)
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID. // DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for _, key := range rm.idMap[id] { for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil { if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
} }
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key. // Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error { func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for key := range rm.refCountMap { for key := range rm.refCountMap {
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc. // Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() { func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
clear(rm.refCountMap) clear(rm.refCountMap)
clear(rm.idMap) clear(rm.idMap)
@@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter. // MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
return json.Marshal(struct { return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"` RefCountMap map[Key]Ref[O] `json:"refCountMap"`

View File

@@ -2,31 +2,28 @@ package systemops
import ( import (
"net/netip" "net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
) )
type ShutdownState struct { type ShutdownState ExclusionCounter
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
return "route_state" return "route_state"
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil) sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter) sysops.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysops.refCounter.Flush()
} }
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

View File

@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
r.updateState(stateManager)
return nexthop, err return nexthop, err
}, },
func(prefix netip.Prefix, nexthop Nexthop) error { r.removeFromRouteTable,
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop)
},
) )
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses) return r.setupHooks(initAddresses, stateManager)
} }
// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) { func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager) if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
} }
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("adding route reference: %v", err) return fmt.Errorf("adding route reference: %v", err)
} }
r.updateState(stateManager)
return nil return nil
} }
afterHook := func(connID nbnet.ConnectionID) error { afterHook := func(connID nbnet.ConnectionID) error {
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("remove route reference: %w", err) return fmt.Errorf("remove route reference: %w", err)
} }
r.updateState(stateManager)
return nil return nil
} }
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes // Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix return isVpn, longestPrefix
} }
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}
return shutdownState
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup // isLegacy determines whether to use the legacy routing setup
func isLegacy() bool { func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
} }
// setIsLegacy sets the legacy routing setup // setIsLegacy sets the legacy routing setup

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
) )
// State interface defines the methods that all state types must implement // State interface defines the methods that all state types must implement
@@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if m.cancel != nil { if m.cancel == nil {
m.cancel() return nil
}
m.cancel()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-m.done: case <-m.done:
return nil
}
} }
return nil return nil
@@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil return nil
} }
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
done := make(chan error, 1) done := make(chan error, 1)
start := time.Now()
go func() { go func() {
data, err := json.MarshalIndent(m.states, "", " ") done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
}() }()
select { select {
@@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
} }
} }
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty) clear(m.dirty)
@@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error
func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()
return bs, err
}

View File

@@ -4,32 +4,20 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
// GetDefaultStatePath returns the path to the state file based on the operating system // GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. // It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string { func GetDefaultStatePath() string {
var path string
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux": case "darwin", "linux":
path = "/var/lib/netbird/state.json" return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly": case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json" return "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
} }
dir := filepath.Dir(path) return ""
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return path
} }

View File

@@ -1,13 +1,41 @@
package main package main
import ( import (
"net/http"
_ "net/http/pprof"
"os" "os"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/cmd" "github.com/netbirdio/netbird/client/cmd"
) )
func main() { func main() {
// only if env is not set
if os.Getenv("NB_LOG_LEVEL") == "" {
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
log.Errorf("Failed setting log-level: %v", err)
}
}
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
log.Errorf("Failed setting log-size: %v", err)
}
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
log.Errorf("Failed setting panic log path: %v", err)
}
go startPprofServer()
if err := cmd.Execute(); err != nil { if err := cmd.Execute(); err != nil {
os.Exit(1) os.Exit(1)
} }
} }
func startPprofServer() {
pprofAddr := "localhost:6969"
log.Infof("Starting pprof debugging server on %s", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Infof("pprof server failed: %v", err)
}
}

View File

@@ -120,6 +120,35 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx) ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel s.actCancel = cancel
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
log.Infof("Error getting status: %v", err)
} else if statusResp.FullStatus != nil {
log.Infof("Status --------")
for _, peer := range statusResp.FullStatus.Peers {
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
peer.Fqdn,
peer.IP,
peer.PubKey,
peer.ConnStatus,
peer.Relayed,
peer.RelayAddress,
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
)
}
}
}
}
}()
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin // if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
// on failure we return error to retry // on failure we return error to retry
config, err := internal.UpdateConfig(s.latestConfigInput) config, err := internal.UpdateConfig(s.latestConfigInput)

View File

@@ -110,7 +110,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@@ -966,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
} }
// UserGroupsAddToPeers adds groups to all peers of user // UserGroupsAddToPeers adds groups to all peers of user
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
userPeers := make(map[string]struct{}) userPeers := make(map[string]struct{})
for pid, peer := range a.Peers { for pid, peer := range a.Peers {
if peer.UserID == userID { if peer.UserID == userID {
@@ -980,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
continue continue
} }
oldPeers := group.Peers
groupPeers := make(map[string]struct{}) groupPeers := make(map[string]struct{})
for _, pid := range group.Peers { for _, pid := range group.Peers {
groupPeers[pid] = struct{}{} groupPeers[pid] = struct{}{}
@@ -993,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
for pid := range groupPeers { for pid := range groupPeers {
group.Peers = append(group.Peers, pid) group.Peers = append(group.Peers, pid)
} }
groupUpdates[gid] = difference(group.Peers, oldPeers)
} }
return groupUpdates
} }
// UserGroupsRemoveFromPeers removes groups from all peers of user // UserGroupsRemoveFromPeers removes groups from all peers of user
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
for _, gid := range groups { for _, gid := range groups {
group, ok := a.Groups[gid] group, ok := a.Groups[gid]
if !ok || group.Name == "All" { if !ok || group.Name == "All" {
continue continue
} }
oldPeers := group.Peers
update := make([]string, 0, len(group.Peers)) update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers { for _, pid := range group.Peers {
peer, ok := a.Peers[pid] peer, ok := a.Peers[pid]
@@ -1014,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
} }
} }
group.Peers = update group.Peers = update
groupUpdates[gid] = difference(oldPeers, group.Peers)
} }
return groupUpdates
} }
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
@@ -1176,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err return nil, err
} }
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
updatedAccount := account.UpdateSettings(newSettings) updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
@@ -1186,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil return updatedAccount, nil
} }
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled if newSettings.GroupsPropagationEnabled {
if !newSettings.PeerInactivityExpirationEnabled { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
event = activity.AccountPeerInactivityExpirationDisabled // Todo: retroactively add user groups to all peers
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { return nil
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) }
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
} }
return nil return nil
@@ -1435,7 +1473,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager // addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -2029,7 +2067,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, accountID) groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2059,7 +2097,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID) groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2083,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@@ -2101,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@@ -2114,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@@ -2127,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil { if err != nil {
return fmt.Errorf("error getting account: %w", err) return err
} }
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
} }
@@ -2290,12 +2333,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, status.NewGetAccountError(err)
} }
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
@@ -2314,12 +2357,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return status.NewGetAccountError(err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
} }
return nil return nil
@@ -2335,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock() defer unlock()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
@@ -2398,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID) am.updateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
} }
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Peers: []string{peer1.ID, peer2.ID, peer3.ID},
} })
require.NoError(t, err, "failed to save group")
policy := Policy{ policy := Policy{
Enabled: true, Enabled: true,
@@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return return
} }
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2773,7 +2775,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims) err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")

View File

@@ -148,6 +148,9 @@ const (
AccountPeerInactivityExpirationDurationUpdated Activity = 67 AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68 SetupKeyDeleted Activity = 68
UserGroupPropagationEnabled Activity = 69
UserGroupPropagationDisabled Activity = 70
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error { func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now() start := time.Now()
err := util.WriteJson(file, s) err := util.WriteJson(context.Background(), file, s)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -6,11 +6,12 @@ import (
"fmt" "fmt"
"slices" "slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
} }
return nil return nil
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account. // SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock. // Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func() var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
for _, newGroup := range newGroups { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { groupIDs := make([]string, 0, len(groups))
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) for _, newGroup := range groups {
} if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
} }
// Avoid duplicate groups only for the API issued groups. newGroup.AccountID = accountID
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. groupsToSave = append(groupsToSave, newGroup)
if existingGroup != nil { groupIDs = append(groupIDs, newGroup.ID)
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String() events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
} }
for _, peerID := range newGroup.Peers { updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if account.Peers[peerID] == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return err
}
} }
oldGroup := account.Groups[newGroup.ID] if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
account.Groups[newGroup.ID] = newGroup return err
}
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
eventsToStore = append(eventsToStore, events...) })
} if err != nil {
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { modifiedPeers := slices.Concat(addedPeers, removedPeers)
peer := account.Peers[p] peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, ok := peers[peerID]
if peer == nil { if !ok {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
}) })
} }
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil
} }
// DeleteGroups deletes groups from an account. // DeleteGroups deletes groups from an account.
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
// //
// If an error occurs while deleting a group, the function skips it and continues deleting other groups. // If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end. // Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var allErrors error var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, ok := account.Groups[groupID] group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if !ok { if err != nil {
continue allErrors = errors.Join(allErrors, err)
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
} }
if err := validateDeleteGroup(account, group, userId); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) return err
continue
} }
delete(account.Groups, groupID) return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
deletedGroups = append(deletedGroups, group) })
} if err != nil {
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
for _, g := range deletedGroups { for _, group := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updateAccountPeers {
if !ok { am.updateAccountPeers(ctx, accountID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updateAccountPeers {
if !ok { am.updateAccountPeers(ctx, accountID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { // validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return err
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return &GroupLinkError{"integrated validator", group.Name} return err
} }
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
} }
return nil return nil
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r return true, r
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies { for _, policy := range policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
} }
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups { for _, dns := range nameServerGroups {
for _, g := range dns.Groups { for _, g := range dns.Groups {
if g == groupID { if g == groupID {
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
} }
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys { for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) { if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey return true, setupKey
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
} }
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return true, user return true, user
@@ -477,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil return false, nil
} }
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers. // anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool { func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
@@ -486,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
} }
return false return false
} }
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}

View File

@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
func (g *Group) HasPeers() bool { func (g *Group) HasPeers() bool {
return len(g.Peers) > 0 return len(g.Peers) > 0
} }
// IsGroupAll checks if the group is a default "All" group.
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@@ -0,0 +1,90 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{ {
name: "delete non-existent group", name: "delete non-existent group",
groupIDs: []string{"non-existent-group"}, groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"}, expectedReasons: []string{"group: non-existent-group not found"},
}, },
{ {
name: "delete multiple groups with mixed results", name: "delete multiple groups with mixed results",

View File

@@ -180,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err) return mapError(ctx, err)
} }
@@ -207,6 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for { for {
select { select {
// condition when there are some updates // condition when there are some updates
@@ -260,10 +262,15 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
unlock := s.acquirePeerLockByUID(ctx, peer.Key) unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock() defer unlock()
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID) s.secretsManager.CancelRefresh(peer.ID)
s.ephemeralManager.OnPeerDisconnected(ctx, peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
} }
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {

View File

@@ -439,17 +439,13 @@ components:
example: 5 example: 5
required: required:
- accessible_peers_count - accessible_peers_count
SetupKey: SetupKeyBase:
type: object type: object
properties: properties:
id: id:
description: Setup Key ID description: Setup Key ID
type: string type: string
example: 2531583362 example: 2531583362
key:
description: Setup Key value
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
name: name:
description: Setup key name identifier description: Setup key name identifier
type: string type: string
@@ -518,22 +514,31 @@ components:
- updated_at - updated_at
- usage_limit - usage_limit
- ephemeral - ephemeral
SetupKeyClear:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as plain text
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
required:
- key
SetupKey:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as secret
type: string
example: A6160****
required:
- key
SetupKeyRequest: SetupKeyRequest:
type: object type: object
properties: properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds, 0 will mean the key never expires
type: integer
minimum: 0
example: 86400
revoked: revoked:
description: Setup key revocation status description: Setup key revocation status
type: boolean type: boolean
@@ -544,21 +549,9 @@ components:
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv7m0" example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required: required:
- name
- type
- expires_in
- revoked - revoked
- auto_groups - auto_groups
- usage_limit
CreateSetupKeyRequest: CreateSetupKeyRequest:
type: object type: object
properties: properties:
@@ -1943,7 +1936,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/SetupKey' $ref: '#/components/schemas/SetupKeyClear'
'400': '400':
"$ref": "#/components/responses/bad_request" "$ref": "#/components/responses/bad_request"
'401': '401':

View File

@@ -1062,7 +1062,94 @@ type SetupKey struct {
// Id Setup Key ID // Id Setup Key ID
Id string `json:"id"` Id string `json:"id"`
// Key Setup Key value // Key Setup Key as secret
Key string `json:"key"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyBase defines model for SetupKeyBase.
type SetupKeyBase struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyClear defines model for SetupKeyClear.
type SetupKeyClear struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key as plain text
Key string `json:"key"` Key string `json:"key"`
// LastUsed Setup key last usage date // LastUsed Setup key last usage date
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key // AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Revoked Setup key revocation status // Revoked Setup key revocation status
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
} }
// User defines model for User. // User defines model for User.

View File

@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers)) peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
for _, peer := range account.Peers { if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupsMap := map[string]*nbgroup.Group{}
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range groups {
groupsMap[group.ID] = group
}
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
} }
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
_, ok := groupsChecked[group.ID] _, ok := groupsChecked[group.ID]

View File

@@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
return return
} }
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil { if req.AutoGroups == nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return return
@@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey := &server.SetupKey{} newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)

View File

@@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a) return am.Store.SaveAccount(ctx, a)
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groups) == 0 { if len(groupIDs) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil { err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
return false, err for _, groupID := range groupIDs {
} _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
for _, group := range groups { if err != nil {
var found bool return err
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
} }
} }
if !found { return nil
return false, nil })
} if err != nil {
return false, err
} }
return true, nil return true, nil

View File

@@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {

View File

@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
} }
if anyGroupHasPeers(account, newNSGroup.Groups) { if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
} }
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
} }
if anyGroupHasPeers(account, nsGroup.Groups) { if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

View File

@@ -110,14 +110,16 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to find peer by pub key: %w", err)
} }
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to update peer status and location: %w", err)
} }
log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected)
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, account)
@@ -131,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired { if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from // we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again. // the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
return nil return nil
@@ -166,9 +168,11 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
account.UpdatePeer(peer) account.UpdatePeer(peer)
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, err return false, fmt.Errorf("failed to save peer status: %w", err)
} }
return oldStatus.LoginExpired, nil return oldStatus.LoginExpired, nil
@@ -267,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
if peerLabelUpdated || requiresPeerUpdates { if peerLabelUpdated || requiresPeerUpdates {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return peer, nil return peer, nil
@@ -331,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
updateAccountPeers := isPeerInActiveGroup(account, peerID) updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
if err != nil {
return err
}
err = am.deletePeers(ctx, account, []string{peerID}, userID) err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil { if err != nil {
@@ -344,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
} }
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -551,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err) return fmt.Errorf("failed to add peer to account: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, accountID) err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -587,17 +594,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err) return nil, nil, nil, status.NewGetAccountError(err)
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
} }
groupsToAdd = append(groupsToAdd, allGroup.ID) groupsToAdd = append(groupsToAdd, allGroup.ID)
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account) newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
if err != nil {
return nil, nil, nil, err
}
if newGroupsAffectsPeers {
am.updateAccountPeers(ctx, accountID)
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -640,7 +652,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if peer.UserID != "" { if peer.UserID != "" {
user, err := account.FindUser(peer.UserID) user, err := account.FindUser(peer.UserID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
} }
err = checkIfPeerOwnerIsBlocked(peer, user) err = checkIfPeerOwnerIsBlocked(peer, user)
@@ -655,19 +667,22 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
updated := peer.UpdateMetaIfNew(sync.Meta) updated := peer.UpdateMetaIfNew(sync.Meta)
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
account.Peers[peer.ID] = peer
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer) err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
} }
if sync.UpdateAccountPeers { if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
} }
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
} }
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
@@ -680,12 +695,12 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
} }
if isStatusChanged { if isStatusChanged {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
validPeersMap, err := am.GetValidatedPeers(account) validPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
} }
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
@@ -765,7 +780,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -787,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updated := peer.UpdateMetaIfNew(login.Meta) updated := peer.UpdateMetaIfNew(login.Meta)
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
shouldStorePeer = true shouldStorePeer = true
} }
@@ -811,7 +827,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
if updateRemotePeers || isStatusChanged { if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
@@ -974,7 +990,13 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
if am.metrics != nil { if am.metrics != nil {
@@ -1028,12 +1050,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used // IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration. // in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool { func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0) peerGroupIDs := make([]string, 0)
for _, group := range account.Groups { for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) { if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID) peerGroupIDs = append(peerGroupIDs, group.ID)
} }
} }
return areGroupChangesAffectPeers(account, peerGroupIDs) return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
} }

View File

@@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now() start := time.Now()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account) manager.updateAccountPeers(ctx, account.Id)
} }
duration := time.Since(start) duration := time.Since(start)

View File

@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) { if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
} }
if isRouteChangeAffectPeers(account, &newRoute) { if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
} }
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) { if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil

View File

@@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"fmt"
"hash/fnv" "hash/fnv"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -229,32 +229,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { if user.AccountID != accountID {
return nil, err return nil, status.NewUserNotPartOfAccountError()
} }
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) if user.IsRegularUser() {
account.SetupKeys[setupKey.Key] = setupKey return nil, status.NewAdminPermissionError()
err = am.Store.SaveAccount(ctx, account) }
var setupKey *SetupKey
var plainKey string
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err
}
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
setupKey.AccountID = accountID
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
})
if err != nil { if err != nil {
return nil, status.Errorf(status.Internal, "failed adding account key") return nil, err
} }
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
for _, storeEvent := range eventsToStore {
for _, g := range setupKey.AutoGroups { storeEvent()
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
} }
// for the creation return the plain key to the caller // for the creation return the plain key to the caller
@@ -266,45 +277,61 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one. // SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
// Due to the unique nature of a SetupKey certain properties must not be overwritten // Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil { if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
} }
account, err := am.Store.GetAccount(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var oldKey *SetupKey var oldKey *SetupKey
for _, key := range account.SetupKeys { var newKey *SetupKey
if key.Id == keyToSave.Id { var eventsToStore []func()
oldKey = key.Copy()
break err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
return err
} }
}
if oldKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
return nil, err if err != nil {
} return err
}
// only auto groups, revoked status, and name can be updated for now if oldKey.Revoked && !keyToSave.Revoked {
newKey := oldKey.Copy() return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
newKey.Name = keyToSave.Name }
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
account.SetupKeys[newKey.Key] = newKey // only auto groups, revoked status (from false to true) can be updated
newKey = oldKey.Copy()
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
if err = am.Store.SaveAccount(ctx, account); err != nil { addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
})
if err != nil {
return nil, err return nil, err
} }
@@ -312,30 +339,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta())
} }
defer func() { for _, storeEvent := range eventsToStore {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) storeEvent()
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) }
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
}
}
}()
return newKey, nil return newKey, nil
} }
@@ -347,16 +353,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError() return nil, status.NewUserNotPartOfAccountError()
} }
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if user.IsRegularUser() {
if err != nil { return nil, status.NewAdminPermissionError()
return nil, err
} }
return setupKeys, nil return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -366,11 +371,15 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError() return nil, status.NewUserNotPartOfAccountError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -387,21 +396,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user: %w", err) return err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError() return status.NewUserNotPartOfAccountError()
} }
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var deletedSetupKey *SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil {
return err
}
return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
})
if err != nil { if err != nil {
return fmt.Errorf("failed to get setup key: %w", err) return err
}
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err)
} }
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
@@ -409,15 +426,62 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return nil return nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
for _, group := range autoGroups { groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
g, ok := account.Groups[group] if err != nil {
return err
}
for _, groupID := range autoGroupIDs {
group, ok := groups[groupID]
if !ok { if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group) return status.Errorf(status.NotFound, "group not found: %s", groupID)
} }
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
} }
} }
return nil return nil
} }
// prepareSetupKeyEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil
}
for _, g := range removedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
})
}
for _, g := range addedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
})
}
return eventsToStore
}

View File

@@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
autoGroups := []string{"group_1", "group_2"} autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key"
revoked := true revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName,
Revoked: revoked, Revoked: revoked,
AutoGroups: autoGroups, AutoGroups: autoGroups,
}, userID) }, userID)
@@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now().UTC(), autoGroups, true) key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
@@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, account.Id, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, keyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
@@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups = append(autoGroups, groupAll.ID) autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName,
Revoked: revoked, Revoked: revoked,
AutoGroups: autoGroups, AutoGroups: autoGroups,
}, userID) }, userID)
@@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ type testCase struct {
ID: "group_2", name string
Name: "group_name_2", keyId string
Peers: []string{}, expectedFailure bool
}) }
if err != nil {
t.Fatal(err) testCase1 := testCase{
name: "Should get existing Setup Key",
keyId: plainKey.Id,
expectedFailure: false,
}
testCase2 := testCase{
name: "Should fail to get non-existent Setup Key",
keyId: "some key",
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
if tCase.expectedFailure {
if err == nil {
t.Fatal("expected to fail")
}
return
}
assert.NotEqual(t, plainKey.Key, key.Key)
})
} }
} }
@@ -449,3 +465,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
} }
}) })
} }
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
assert.NoError(t, err)
// revoke the key
updateKey := key.Copy()
updateKey.Revoked = true
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
assert.NoError(t, err)
// re-activate revoked key
updateKey.Revoked = false
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
assert.Error(t, err, "should not allow to update revoked key")
}

View File

@@ -33,12 +33,13 @@ import (
) )
const ( const (
storeSqliteFileName = "store.db" storeSqliteFileName = "store.db"
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
keyQueryCondition = "key = ?" keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?" accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
peerNotFoundFMT = "peer %s not found" accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
) )
// SqlStore represents an account storage backed by a Sql DB persisted to disk // SqlStore represents an account storage backed by a Sql DB persisted to disk
@@ -485,9 +486,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.NewSetupKeyNotFoundError(setupKey)
} }
return nil, status.NewSetupKeyNotFoundError(result.Error) log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
} }
if key.AccountID == "" { if key.AccountID == "" {
@@ -554,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil return &user, nil
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User var users []*User
result := s.db.Find(&users, accountIDCondition, accountID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -568,15 +570,15 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
return users, nil return users, nil
} }
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
} }
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store") return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
} }
return groups, nil return groups, nil
@@ -756,9 +758,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.NewSetupKeyNotFoundError(setupKey)
} }
return "", status.NewSetupKeyNotFoundError(result.Error) log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
} }
if accountID == "" { if accountID == "" {
@@ -855,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID) return status.NewUserNotFoundError(userID)
} }
return status.NewGetUserFromStoreError() return status.NewGetUserFromStoreError()
} }
user.LastLogin = lastLogin user.LastLogin = lastLogin
@@ -986,9 +988,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
First(&setupKey, keyQueryCondition, key) First(&setupKey, keyQueryCondition, key)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found") return nil, status.NewSetupKeyNotFoundError(key)
} }
return nil, status.NewSetupKeyNotFoundError(result.Error) log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
} }
return &setupKey, nil return &setupKey, nil
} }
@@ -1006,7 +1009,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found") return status.NewSetupKeyNotFoundError(setupKeyID)
} }
return nil return nil
@@ -1042,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account") return status.NewGroupNotFoundError(groupID)
} }
return status.Errorf(status.Internal, "issue finding group: %s", result.Error) return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
@@ -1076,15 +1079,51 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
return nil return nil
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { // GetPeerByID retrieves a peer by its ID and account ID.
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
var peer *nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
return peer, nil
}
// GetPeersByIDs retrieves peers by their IDs and account ID.
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
}
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return peersMap, nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
} }
return nil return nil
} }
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
startTime := time.Now()
tx := s.db.Begin() tx := s.db.Begin()
if tx.Error != nil { if tx.Error != nil {
return tx.Error return tx.Error
@@ -1095,12 +1134,21 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
tx.Rollback() tx.Rollback()
return err return err
} }
return tx.Commit().Error
err = tx.Commit().Error
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
if s.metrics != nil {
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
}
return err
} }
func (s *SqlStore) withTx(tx *gorm.DB) Store { func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{ return &SqlStore{
db: tx, db: tx,
storeEngine: s.storeEngine,
} }
} }
@@ -1152,12 +1200,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
} }
// GetGroupByID retrieves a group by ID and account ID. // GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID) var group *nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
}
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
return group, nil
} }
// GetGroupByName retrieves a group by name and account ID. // GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
var group nbgroup.Group var group nbgroup.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently // TODO: This fix is accepted for now, but if we need to handle this more frequently
@@ -1169,25 +1227,72 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query = query.Order("json_array_length(peers) DESC") query = query.Order("json_array_length(peers) DESC")
} }
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found") return nil, status.NewGroupNotFoundError(groupName)
} }
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
} }
return &group, nil return &group, nil
} }
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
}
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group
}
return groupsMap, nil
}
// SaveGroup saves a group to the store. // SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save group to store")
} }
return nil return nil
} }
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete group from store")
}
if result.RowsAffected == 0 {
return status.NewGroupNotFoundError(groupID)
}
return nil
}
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete groups from store")
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
@@ -1220,12 +1325,57 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db, lockStrength, accountID) var setupKeys []*SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&setupKeys, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
}
return setupKeys, nil
} }
// GetSetupKeyByID retrieves a setup key by its ID and account ID. // GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID) var setupKey *SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
}
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
}
return setupKey, nil
}
// SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save setup key to store")
}
return nil
}
// DeleteSetupKey deletes a setup key from the database.
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete setup key from store")
}
if result.RowsAffected == 0 {
return status.NewSetupKeyNotFoundError(keyID)
}
return nil
} }
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
@@ -1238,10 +1388,6 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
} }
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID)
}
// getRecords retrieves records from the database based on the account ID. // getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T var record []T
@@ -1274,21 +1420,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
} }
return &record, nil return &record, nil
} }
// deleteRecordByID deletes a record by its ID and account ID from the database.
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "record not found")
}
return nil
}

View File

@@ -14,11 +14,10 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route" route2 "github.com/netbirdio/netbird/route"
@@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group") t.Fatal("failed to save group")
return err return err
} }
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil { if err != nil {
t.Fatal("failed to get group") t.Fatal("failed to get group")
return err return err
@@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID) users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, users, len(account.Users)) require.Len(t, users, len(account.Users))
} }
@@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
} }
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "All", group.Name) require.True(t, group.IsGroupAll())
} }
func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) {
@@ -1274,7 +1273,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
require.NoError(t, err) require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
@@ -1290,6 +1289,278 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id" nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err) require.Error(t, err)
} }
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectedCount int
}{
{
name: "retrieve existing groups by existing IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectedCount: 2,
},
{
name: "empty group IDs list",
groupIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing group IDs",
groupIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing group IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
require.NoError(t, err)
require.Len(t, groups, tt.expectedCount)
})
}
}
func TestSqlStore_SaveGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &nbgroup.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
require.NoError(t, err)
require.Equal(t, savedGroup, group)
}
func TestSqlStore_SaveGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups := []*nbgroup.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
require.NoError(t, err)
}
func TestSqlStore_DeleteGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupID string
expectError bool
}{
{
name: "delete existing group",
groupID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "delete non-existing group",
groupID: "non-existing-group-id",
expectError: true,
},
{
name: "delete with empty group ID",
groupID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
} else {
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
require.Error(t, err)
require.Nil(t, group)
}
})
}
}
func TestSqlStore_DeleteGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectError bool
}{
{
name: "delete multiple existing groups",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectError: false,
},
{
name: "delete non-existing groups",
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
expectError: false,
},
{
name: "delete with empty group IDs list",
groupIDs: []string{},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
for _, groupID := range tt.groupIDs {
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
require.Error(t, err)
require.Nil(t, group)
}
}
})
}
}
func TestSqlStore_GetPeerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerID string
expectError bool
}{
{
name: "retrieve existing peer",
peerID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "retrieve non-existing peer",
peerID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty peer ID",
peerID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, peer)
} else {
require.NoError(t, err)
require.NotNil(t, peer)
require.Equal(t, tt.peerID, peer.ID)
}
})
}
}
func TestSqlStore_GetPeersByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerIDs []string
expectedCount int
}{
{
name: "retrieve existing peers by existing IDs",
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
expectedCount: 2,
},
{
name: "empty peer IDs list",
peerIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing peer IDs",
peerIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing peer IDs",
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
})
}
}

View File

@@ -3,7 +3,6 @@ package status
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
) )
const ( const (
@@ -103,22 +102,27 @@ func NewPeerLoginExpiredError() error {
} }
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error { func NewSetupKeyNotFoundError(setupKeyID string) error {
return Errorf(NotFound, "setup key not found: %s", err) return Errorf(NotFound, "setup key: %s not found", setupKeyID)
} }
func NewGetAccountFromStoreError(err error) error { func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err) return Errorf(Internal, "issue getting account from store: %s", err)
} }
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
func NewUserNotPartOfAccountError() error {
return Errorf(PermissionDenied, "user is not part of this account")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error { func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store") return Errorf(Internal, "issue getting user from store")
} }
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context // NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
func NewStoreContextCanceledError(duration time.Duration) error { func NewAdminPermissionError() error {
return Errorf(Internal, "store access: context canceled after %v", duration) return Errorf(PermissionDenied, "admin role required to perform this action")
} }
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
@@ -126,7 +130,12 @@ func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID") return Errorf(InvalidArgument, "invalid key ID")
} }
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key // NewGetAccountError creates a new Error with Internal type for an issue getting account
func NewUnauthorizedToViewSetupKeysError() error { func NewGetAccountError(err error) error {
return Errorf(Unauthorized, "only users with admin power can view setup keys") return Errorf(Internal, "error getting account: %s", err)
}
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
} }

View File

@@ -62,7 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@@ -70,11 +70,14 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -89,6 +92,8 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@@ -96,7 +101,9 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
@@ -105,7 +112,7 @@ type Store interface {
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string
@@ -124,7 +131,6 @@ type Store interface {
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
} }
type StoreEngine string type StoreEngine string

View File

@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
updateAccountPeersDurationMs metric.Float64Histogram updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter
} }
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics // NewAccountManagerMetrics creates an instance of AccountManagerMetrics
@@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err return nil, err
} }
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{ return &AccountManagerMetrics{
ctx: ctx, ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount, networkMapObjectCount: networkMapObjectCount,
peerMetaUpdateCount: peerMetaUpdateCount,
}, nil }, nil
} }
@@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count) metrics.networkMapObjectCount.Record(metrics.ctx, count)
} }
// CountPeerMetUpdate counts the number of peer meta updates
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
}

View File

@@ -13,6 +13,7 @@ type StoreMetrics struct {
globalLockAcquisitionDurationMs metric.Int64Histogram globalLockAcquisitionDurationMs metric.Int64Histogram
persistenceDurationMicro metric.Int64Histogram persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
ctx context.Context ctx context.Context
} }
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err return nil, err
} }
transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
if err != nil {
return nil, err
}
return &StoreMetrics{ return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro, globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs, globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro, persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs, persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
ctx: ctx, ctx: ctx,
}, nil }, nil
} }
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds()) metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds()) metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
} }
// CountTransactionDuration counts the duration of a store persistence operation
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
}

View File

@@ -96,9 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok { if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID) delete(p.peerChannels, peerID)
close(channel) close(channel)
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
return
} }
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
} }
// CloseChannels closes updates channel for each given peer // CloseChannels closes updates channel for each given peer

View File

@@ -9,14 +9,16 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -103,6 +105,11 @@ func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser return u.HasAdminPower() || u.IsServiceUser
} }
// IsRegularUser checks if the user is a regular user.
func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object. // ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups autoGroups := u.AutoGroups
@@ -487,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
return nil return nil
@@ -798,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
expiredPeers = append(expiredPeers, blockedPeers...) expiredPeers = append(expiredPeers, blockedPeers...)
} }
peerGroupsAdded := make(map[string][]string)
peerGroupsRemoved := make(map[string][]string)
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated // need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
} }
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, userUpdateEvents...)
userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
eventsToStore = append(eventsToStore, userGroupsEvents...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil { if err != nil {
@@ -828,7 +840,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
@@ -865,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
}) })
} }
return eventsToStore
}
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
if newUser.AutoGroups != nil { if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else { removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) eventsToStore = append(eventsToStore, removedEvents...)
}
} addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
for _, g := range addedGroups { eventsToStore = append(eventsToStore, addedEvents...)
group := account.GetGroup(g) }
if group != nil { return eventsToStore
eventsToStore = append(eventsToStore, func() { }
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
}) var eventsToStore []func()
} for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} }
} }
for groupID, peerIDs := range peerGroupsAdded {
group := account.GetGroup(groupID)
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
})
}
}
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
for groupID, peerIDs := range peerGroupsRemoved {
group := account.GetGroup(groupID)
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
})
}
}
return eventsToStore return eventsToStore
} }
@@ -1100,6 +1158,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
var peerIDs []string var peerIDs []string
for _, peer := range peers { for _, peer := range peers {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
if peer.Status.LoginExpired { if peer.Status.LoginExpired {
continue continue
} }
@@ -1107,8 +1168,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peer.MarkLoginExpired(true) peer.MarkLoginExpired(true)
account.UpdatePeer(peer) account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
return err return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err)
} }
log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID)
am.StoreEvent( am.StoreEvent(
ctx, ctx,
peer.UserID, peer.ID, account.Id, peer.UserID, peer.ID, account.Id,
@@ -1119,7 +1183,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 { if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service // this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account.Id)
} }
return nil return nil
} }
@@ -1227,7 +1291,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
} }
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
for targetUserID, meta := range deletedUsersMeta { for targetUserID, meta := range deletedUsersMeta {

View File

@@ -16,6 +16,8 @@ import (
const ( const (
bufferSize = 8820 bufferSize = 8820
errCloseConn = "failed to close connection to peer: %s"
) )
// Peer represents a peer connection // Peer represents a peer connection
@@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
defer func() {
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err)
}
}()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
case messages.MsgTypeClose: case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully") p.log.Infof("peer exited gracefully")
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf("failed to close connection to peer: %s", err) log.Errorf(errCloseConn, err)
} }
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
@@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
} }
err = p.conn.Close() if err := p.conn.Close(); err != nil {
if err != nil { p.log.Errorf(errCloseConn, err)
p.log.Errorf("failed to close connection to peer: %s", err)
} }
} }
@@ -132,7 +139,7 @@ func (p *Peer) Close() {
defer p.connMu.Unlock() defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf(errCloseConn, err)
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -14,8 +15,21 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return fmt.Errorf("prepare config file dir: %w", err)
}
if err = EnforcePermission(file); err != nil {
return fmt.Errorf("enforce permission: %w", err)
}
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
@@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
return err return err
} }
return writeJson(file, obj, configDir, configFileName) return writeJson(ctx, file, obj, configDir, configFileName)
} }
// WriteJson writes JSON config object to a file creating parent directories if required // WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted // The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error { func WriteJson(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
} }
return writeJson(file, obj, configDir, configFileName) return writeJson(ctx, file, obj, configDir, configFileName)
} }
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil return nil
} }
func writeJson(file string, obj interface{}, configDir string, configFileName string) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
return fmt.Errorf("write json start: %w", ctx.Err())
}
// make it pretty // make it pretty
bs, err := json.MarshalIndent(obj, "", " ") bs, err := json.MarshalIndent(obj, "", " ")
if err != nil { if err != nil {
return err return fmt.Errorf("marshal: %w", err)
}
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
if ctx.Err() != nil {
return fmt.Errorf("write bytes start: %w", ctx.Err())
} }
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil { if err != nil {
return err return fmt.Errorf("create temp: %w", err)
} }
tempFileName := tempFile.Name() tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close() if deadline, ok := ctx.Deadline(); ok {
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
log.Warnf("failed to set deadline: %v", err)
}
}
_, err = tempFile.Write(bs)
if err != nil { if err != nil {
return err _ = tempFile.Close()
return fmt.Errorf("write: %w", err)
}
if err = tempFile.Close(); err != nil {
return fmt.Errorf("close %s: %w", tempFileName, err)
} }
defer func() { defer func() {
@@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
} }
}() }()
err = os.WriteFile(tempFileName, bs, 0600) // Check context again
if err != nil { if ctx.Err() != nil {
return err return fmt.Errorf("after temp file: %w", ctx.Err())
} }
err = os.Rename(tempFileName, file) if err = os.Rename(tempFileName, file); err != nil {
if err != nil { return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
return err
} }
return nil return nil

View File

@@ -1,6 +1,7 @@
package util package util
import ( import (
"context"
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"io" "io"
@@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
err := WriteJson(tmpDir+"/testconfig.json", tt.config) err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err) require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
@@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
src := tmpDir + "/copytest_src" src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst" dst := tmpDir + "/copytest_dst"
err := WriteJson(src, tt.srcContent) err := WriteJson(context.Background(), src, tt.srcContent)
require.NoError(t, err) require.NoError(t, err)
err = CopyFileContents(src, dst) err = CopyFileContents(src, dst)
@@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
_ = os.Remove(cfgFile) _ = os.Remove(cfgFile)
}() }()
err := WriteJson(cfgFile, tt.config) err := WriteJson(context.Background(), cfgFile, tt.config)
require.NoError(t, err) require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{}) read, err := ReadJson(cfgFile, &TestConfig{})

View File

@@ -3,6 +3,9 @@ package grpc
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net" "net"
"os/user" "os/user"
"runtime" "runtime"
@@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
currentUser, err := user.Current() currentUser, err := user.Current()
if err != nil { if err != nil {
log.Fatalf("failed to get current user: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
} }
// the custom dialer requires root permissions which are not required for use cases run as non-root // the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" { if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{} dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr) return dialer.DialContext(ctx, "tcp", addr)
} }
} }
log.Debug("Using nbnet.NewDialer()")
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err) log.Errorf("Failed to dial: %s", err)
return nil, err return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
} }
return conn, nil return conn, nil
}) })

View File

@@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, fmt.Errorf("dial: %w", err) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
} }
// Wrap the connection in Conn to handle Close with hooks // Wrap the connection in Conn to handle Close with hooks

View File

@@ -4,9 +4,14 @@ package net
import ( import (
"fmt" "fmt"
"os"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
) )
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
// SetSocketMark sets the SO_MARK option on the given socket connection // SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error { func SetSocketMark(conn syscall.Conn) error {
sysconn, err := conn.SyscallConn() sysconn, err := conn.SyscallConn()
@@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
func SetSocketOpt(fd int) error { func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() { if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return nil
}
// Check for the new environment variable
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil return nil
} }