mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Compare commits
26 Commits
test/seed-
...
debug-0.33
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3efa7a282a | ||
|
|
40551099b3 | ||
|
|
9db1932664 | ||
|
|
1bbabf70b0 | ||
|
|
aa575d6f44 | ||
|
|
f66bbcc54c | ||
|
|
5dd6a08ea6 | ||
|
|
eb5d0569ae | ||
|
|
52ea2e84e9 | ||
|
|
78fab877c0 | ||
|
|
65a94f695f | ||
|
|
ec543f89fb | ||
|
|
a7d5c52203 | ||
|
|
582bb58714 | ||
|
|
121dfda915 | ||
|
|
a1c5287b7c | ||
|
|
12f442439a | ||
|
|
d9b691b8a5 | ||
|
|
4aee3c9e33 | ||
|
|
44e799c687 | ||
|
|
be78efbd42 | ||
|
|
6886691213 | ||
|
|
b48afd92fd | ||
|
|
39329e12a1 | ||
|
|
20a5afc359 | ||
|
|
6cb697eed6 |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(path, config)
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
defer func() {
|
||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
c.statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
|
||||
@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Domains: []string{"google.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -38,7 +39,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -171,7 +171,7 @@ type Engine struct {
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
srWatcher *guard.SRWatcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
@@ -641,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
}
|
||||
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
oldAddr := e.wgInterface.Address().String()
|
||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||
@@ -1481,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
sort.Slice(check.Files, func(i, j int) bool {
|
||||
return check.Files[i] < check.Files[j]
|
||||
})
|
||||
}
|
||||
for _, oCheck := range oChecks {
|
||||
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||
return oCheck.Files[i] < oCheck.Files[j]
|
||||
})
|
||||
}
|
||||
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
|
||||
@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckFilesEqual(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputChecks1 []*mgmtProto.Checks
|
||||
inputChecks2 []*mgmtProto.Checks
|
||||
expectedBool bool
|
||||
}{
|
||||
{
|
||||
name: "Equal Files In Equal Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Equal Files In Reverse Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile2",
|
||||
"testfile1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Unequal Files Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile3",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
{
|
||||
name: "Compared With Empty Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "relayed routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "p2p routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "current route with bad score should be changed to route with better score",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// fill the test data with random routes
|
||||
for _, tc := range testCases {
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
|
||||
@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
|
||||
type Counter[Key comparable, I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for keys
|
||||
refCountMap map[Key]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
// idMap keeps track of the keys associated with an ID for removal
|
||||
idMap map[string][]Key
|
||||
idMu sync.Mutex
|
||||
add AddFunc[Key, I, O]
|
||||
remove RemoveFunc[Key, O]
|
||||
}
|
||||
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
||||
func (rm *Counter[Key, I, O]) LoadData(
|
||||
existingCounter *Counter[Key, I, O],
|
||||
) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
rm.refCountMap = existingCounter.refCountMap
|
||||
rm.idMap = existingCounter.idMap
|
||||
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
|
||||
// Get retrieves the current reference count and associated data for a key.
|
||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[key]
|
||||
return ref, ok
|
||||
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
// Increment increments the reference count for the given key.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return rm.increment(key, in)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
|
||||
ref := rm.refCountMap[key]
|
||||
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||
|
||||
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(key, in)
|
||||
ref, err := rm.increment(key, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
|
||||
// Decrement decrements the reference count for the given key.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
return rm.decrement(key)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
|
||||
ref, ok := rm.refCountMap[key]
|
||||
if !ok {
|
||||
logCallerF("No reference found for key %v", key)
|
||||
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, key := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(key); err != nil {
|
||||
if _, err := rm.decrement(key); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each key.
|
||||
func (rm *Counter[Key, I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for key := range rm.refCountMap {
|
||||
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
|
||||
|
||||
// Clear removes all references without calling RemoveFunc.
|
||||
func (rm *Counter[Key, I, O]) Clear() {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
clear(rm.refCountMap)
|
||||
clear(rm.idMap)
|
||||
@@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
|
||||
@@ -2,31 +2,28 @@ package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
Counter *ExclusionCounter `json:"counter,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
type ShutdownState ExclusionCounter
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "route_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.Counter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysops := NewSysOps(nil, nil)
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData(s.Counter)
|
||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
|
||||
return sysops.refCounter.Flush()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
return (*ExclusionCounter)(s).MarshalJSON()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||
return (*ExclusionCounter)(s).UnmarshalJSON(data)
|
||||
}
|
||||
|
||||
@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nexthop, err
|
||||
},
|
||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
// remove from state even if we have trouble removing it from the route table
|
||||
// it could be already gone
|
||||
r.updateState(stateManager)
|
||||
|
||||
return r.removeFromRouteTable(prefix, nexthop)
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses)
|
||||
return r.setupHooks(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
// updateState updates state on every change so it will be persisted regularly
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||
state := getState(stateManager)
|
||||
|
||||
state.Counter = r.refCounter
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
||||
|
||||
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||
var shutdownState *ShutdownState
|
||||
if state := stateManager.GetState(shutdownState); state != nil {
|
||||
shutdownState = state.(*ShutdownState)
|
||||
} else {
|
||||
shutdownState = &ShutdownState{}
|
||||
}
|
||||
|
||||
return shutdownState
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ type ruleParams struct {
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// State interface defines the methods that all state types must implement
|
||||
@@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
m.cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
bs, err := marshalWithPanicRecovery(m.states)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal states: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
|
||||
start := time.Now()
|
||||
go func() {
|
||||
data, err := json.MarshalIndent(m.states, "", " ")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("marshal states: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:gosec
|
||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
||||
done <- fmt.Errorf("write state file: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
||||
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||
|
||||
clear(m.dirty)
|
||||
|
||||
@@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
||||
var bs []byte
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic during marshal: %v", r)
|
||||
}
|
||||
}()
|
||||
bs, err = json.Marshal(v)
|
||||
}()
|
||||
|
||||
return bs, err
|
||||
}
|
||||
|
||||
@@ -4,32 +4,20 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
||||
// It returns an empty string if the path cannot be determined.
|
||||
func GetDefaultStatePath() string {
|
||||
var path string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
case "darwin", "linux":
|
||||
path = "/var/lib/netbird/state.json"
|
||||
return "/var/lib/netbird/state.json"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
path = "/var/db/netbird/state.json"
|
||||
// ios/android don't need state
|
||||
default:
|
||||
return ""
|
||||
return "/var/db/netbird/state.json"
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
@@ -1,13 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// only if env is not set
|
||||
if os.Getenv("NB_LOG_LEVEL") == "" {
|
||||
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
|
||||
log.Errorf("Failed setting log-level: %v", err)
|
||||
}
|
||||
}
|
||||
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
|
||||
log.Errorf("Failed setting log-size: %v", err)
|
||||
}
|
||||
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
|
||||
log.Errorf("Failed setting panic log path: %v", err)
|
||||
}
|
||||
|
||||
go startPprofServer()
|
||||
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func startPprofServer() {
|
||||
pprofAddr := "localhost:6969"
|
||||
log.Infof("Starting pprof debugging server on %s", pprofAddr)
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
log.Infof("pprof server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +120,35 @@ func (s *Server) Start() error {
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
|
||||
log.Infof("Error getting status: %v", err)
|
||||
} else if statusResp.FullStatus != nil {
|
||||
log.Infof("Status --------")
|
||||
for _, peer := range statusResp.FullStatus.Peers {
|
||||
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
|
||||
peer.Fqdn,
|
||||
peer.IP,
|
||||
peer.PubKey,
|
||||
peer.ConnStatus,
|
||||
peer.Relayed,
|
||||
peer.RelayAddress,
|
||||
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
|
||||
// on failure we return error to retry
|
||||
config, err := internal.UpdateConfig(s.latestConfigInput)
|
||||
|
||||
@@ -110,7 +110,6 @@ type AccountManager interface {
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||
@@ -966,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
|
||||
}
|
||||
|
||||
// UserGroupsAddToPeers adds groups to all peers of user
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
userPeers := make(map[string]struct{})
|
||||
for pid, peer := range a.Peers {
|
||||
if peer.UserID == userID {
|
||||
@@ -980,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
groupPeers := make(map[string]struct{})
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
@@ -993,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
groupUpdates[gid] = difference(group.Peers, oldPeers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// UserGroupsRemoveFromPeers removes groups from all peers of user
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
for _, gid := range groups {
|
||||
group, ok := a.Groups[gid]
|
||||
if !ok || group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
peer, ok := a.Peers[pid]
|
||||
@@ -1014,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
}
|
||||
}
|
||||
group.Peers = update
|
||||
groupUpdates[gid] = difference(oldPeers, group.Peers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
@@ -1176,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("groups propagation failed: %w", err)
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@@ -1186,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
|
||||
if newSettings.GroupsPropagationEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
|
||||
// Todo: retroactively add user groups to all peers
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
} else {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1435,7 +1473,7 @@ func isNil(i idp.Manager) bool {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2029,7 +2067,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := transaction.GetAccountGroups(ctx, accountID)
|
||||
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -2059,7 +2097,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
groups, err = transaction.GetAccountGroups(ctx, accountID)
|
||||
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -2083,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -2101,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@@ -2114,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@@ -2127,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2290,12 +2333,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, status.NewGetAccountError(err)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
|
||||
@@ -2314,12 +2357,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
return status.NewGetAccountError(err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -2335,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer unlockPeer()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2398,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
return
|
||||
}
|
||||
am.updateAccountPeers(ctx, updatedAccount)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
|
||||
@@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
@@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2773,7 +2775,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
|
||||
@@ -148,6 +148,9 @@ const (
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
|
||||
SetupKeyDeleted Activity = 68
|
||||
|
||||
UserGroupPropagationEnabled Activity = 69
|
||||
UserGroupPropagationDisabled Activity = 70
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
|
||||
|
||||
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
|
||||
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
// It is recommended to call it with locking FileStore.mux
|
||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
start := time.Now()
|
||||
err := util.WriteJson(file, s)
|
||||
err := util.WriteJson(context.Background(), file, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccountGroups(ctx, accountID)
|
||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
// SaveGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
for _, newGroup := range newGroups {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
return err
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Avoid duplicate groups only for the API issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldGroup := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
newGroupIDs := make([]string, 0, len(newGroups))
|
||||
for _, newGroup := range newGroups {
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
if oldGroup != nil {
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
|
||||
continue
|
||||
}
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
for _, peerID := range removedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
|
||||
continue
|
||||
}
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
return nil
|
||||
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*nbgroup.Group
|
||||
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
for _, group := range deletedGroups {
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
for _, item := range account.Groups {
|
||||
groups = append(groups, item)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
add := true
|
||||
for _, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
add = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if add {
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
// validateNewGroup validates the new group for existence and required fields.
|
||||
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Prevent duplicate groups for API-issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||
}
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
@@ -477,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
@@ -486,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
|
||||
func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
||||
// IsGroupAll checks if the group is a default "All" group.
|
||||
func (g *Group) IsGroupAll() bool {
|
||||
return g.Name == "All"
|
||||
}
|
||||
|
||||
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||
func (g *Group) AddPeer(peerID string) bool {
|
||||
if peerID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
g.Peers = append(g.Peers, peerID)
|
||||
return true
|
||||
}
|
||||
|
||||
// RemovePeer removes peerID from Peers if present, returning true if removed.
|
||||
func (g *Group) RemovePeer(peerID string) bool {
|
||||
for i, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
90
management/server/group/group_test.go
Normal file
90
management/server/group/group_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddPeer(t *testing.T) {
|
||||
t.Run("add new peer to empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to non-empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add duplicate peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("add empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
t.Run("remove existing peer from slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
|
||||
peerID := "peer2"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Nil(t, group.Peers)
|
||||
})
|
||||
|
||||
t.Run("remove non-existent peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from single-item slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1"}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("remove empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
{
|
||||
name: "delete non-existent group",
|
||||
groupIDs: []string{"non-existent-group"},
|
||||
expectedDeleted: []string{"non-existent-group"},
|
||||
expectedReasons: []string{"group: non-existent-group not found"},
|
||||
},
|
||||
{
|
||||
name: "delete multiple groups with mixed results",
|
||||
|
||||
@@ -180,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -207,6 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@@ -260,10 +262,15 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
}
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
|
||||
@@ -439,17 +439,13 @@ components:
|
||||
example: 5
|
||||
required:
|
||||
- accessible_peers_count
|
||||
SetupKey:
|
||||
SetupKeyBase:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: Setup Key ID
|
||||
type: string
|
||||
example: 2531583362
|
||||
key:
|
||||
description: Setup Key value
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
name:
|
||||
description: Setup key name identifier
|
||||
type: string
|
||||
@@ -518,22 +514,31 @@ components:
|
||||
- updated_at
|
||||
- usage_limit
|
||||
- ephemeral
|
||||
SetupKeyClear:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as plain text
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
required:
|
||||
- key
|
||||
SetupKey:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as secret
|
||||
type: string
|
||||
example: A6160****
|
||||
required:
|
||||
- key
|
||||
SetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Setup Key name
|
||||
type: string
|
||||
example: Default key
|
||||
type:
|
||||
description: Setup key type, one-off for single time usage and reusable
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds, 0 will mean the key never expires
|
||||
type: integer
|
||||
minimum: 0
|
||||
example: 86400
|
||||
revoked:
|
||||
description: Setup key revocation status
|
||||
type: boolean
|
||||
@@ -544,21 +549,9 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||
usage_limit:
|
||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
type: integer
|
||||
example: 0
|
||||
ephemeral:
|
||||
description: Indicate that the peer will be ephemeral or not
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
CreateSetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -1943,7 +1936,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetupKey'
|
||||
$ref: '#/components/schemas/SetupKeyClear'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
|
||||
@@ -1062,7 +1062,94 @@ type SetupKey struct {
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key value
|
||||
// Key Setup Key as secret
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyBase defines model for SetupKeyBase.
|
||||
type SetupKeyBase struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyClear defines model for SetupKeyClear.
|
||||
type SetupKeyClear struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key as plain text
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// User defines model for User.
|
||||
|
||||
@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groupsMap := map[string]*nbgroup.Group{}
|
||||
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsInfo := []api.GroupMinimum{}
|
||||
groupsChecked := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
|
||||
@@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
@@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey := &server.SetupKey{}
|
||||
newKey.AutoGroups = req.AutoGroups
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
|
||||
@@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if len(groups) == 0 {
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, group := range groups {
|
||||
var found bool
|
||||
for _, accountGroup := range accountsGroups {
|
||||
if accountGroup.ID == group {
|
||||
found = true
|
||||
break
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false, nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
@@ -45,7 +45,6 @@ type MockAccountManager struct {
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
@@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
return am.ListGroupsFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
||||
}
|
||||
|
||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
if am.GroupAddPeerFunc != nil {
|
||||
|
||||
@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
@@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
@@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
|
||||
@@ -110,14 +110,16 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to find peer by pub key: %w", err)
|
||||
}
|
||||
|
||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to update peer status and location: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected)
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
@@ -131,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
if expired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -166,9 +168,11 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("failed to save peer status: %w", err)
|
||||
}
|
||||
|
||||
return oldStatus.LoginExpired, nil
|
||||
@@ -267,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
if peerLabelUpdated || requiresPeerUpdates {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
@@ -331,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||
if err != nil {
|
||||
@@ -344,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -551,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -587,17 +594,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
return nil, nil, nil, status.NewGetAccountError(err)
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
||||
}
|
||||
|
||||
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if newGroupsAffectsPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
@@ -640,7 +652,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
if peer.UserID != "" {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||
@@ -655,19 +667,22 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
account.Peers[peer.ID] = peer
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
||||
}
|
||||
|
||||
if sync.UpdateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
|
||||
}
|
||||
|
||||
var postureChecks []*posture.Checks
|
||||
@@ -680,12 +695,12 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
|
||||
if isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
validPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
@@ -765,7 +780,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -787,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
@@ -811,7 +827,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
|
||||
@@ -974,7 +990,13 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
@@ -1028,12 +1050,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
|
||||
peerGroupIDs := make([]string, 0)
|
||||
for _, group := range account.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||
}
|
||||
}
|
||||
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
|
||||
}
|
||||
|
||||
@@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
manager.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
@@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
@@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
if isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||
|
||||
groups, err := am.ListGroups(context.Background(), account.Id)
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
|
||||
require.NoError(t, err)
|
||||
var groupHA1, groupHA2 *nbgroup.Group
|
||||
for _, group := range groups {
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -229,32 +229,43 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
|
||||
return nil, err
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
var plainKey string
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
setupKey.AccountID = accountID
|
||||
|
||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed adding account key")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
|
||||
|
||||
for _, g := range setupKey.AutoGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
}
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
// for the creation return the plain key to the caller
|
||||
@@ -266,45 +277,61 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
||||
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyToSave.Id {
|
||||
oldKey = key.Copy()
|
||||
break
|
||||
var newKey *SetupKey
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if oldKey == nil {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newKey := oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
if oldKey.Revoked && !keyToSave.Revoked {
|
||||
return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
|
||||
}
|
||||
|
||||
account.SetupKeys[newKey.Key] = newKey
|
||||
// only auto groups, revoked status (from false to true) can be updated
|
||||
newKey = oldKey.Copy()
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
|
||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -312,30 +339,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta())
|
||||
}
|
||||
|
||||
defer func() {
|
||||
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
return newKey, nil
|
||||
}
|
||||
@@ -347,16 +353,15 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return setupKeys, nil
|
||||
return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
@@ -366,11 +371,15 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -387,21 +396,29 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var deletedSetupKey *SetupKey
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete setup key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
|
||||
@@ -409,15 +426,62 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
for _, group := range autoGroups {
|
||||
g, ok := account.Groups[group]
|
||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range autoGroupIDs {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
return status.Errorf(status.NotFound, "group not found: %s", groupID)
|
||||
}
|
||||
if g.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
|
||||
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareSetupKeyEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range removedGroups {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupRemovedFromSetupKey, meta)
|
||||
})
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": key.Name}
|
||||
am.StoreEvent(ctx, userID, key.Id, accountID, activity.GroupAddedToSetupKey, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
@@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
newKeyName := "my-new-test-key"
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
@@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
key.Id, time.Now().UTC(), autoGroups, true)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
@@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, newKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, keyName, ev.Meta["name"])
|
||||
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
@@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
@@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
})
|
||||
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
type testCase struct {
|
||||
name string
|
||||
keyId string
|
||||
expectedFailure bool
|
||||
}
|
||||
|
||||
testCase1 := testCase{
|
||||
name: "Should get existing Setup Key",
|
||||
keyId: plainKey.Id,
|
||||
expectedFailure: false,
|
||||
}
|
||||
testCase2 := testCase{
|
||||
name: "Should fail to get non-existent Setup Key",
|
||||
keyId: "some key",
|
||||
expectedFailure: true,
|
||||
}
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
if err == nil {
|
||||
t.Fatal("expected to fail")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotEqual(t, plainKey.Key, key.Key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,3 +465,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// revoke the key
|
||||
updateKey := key.Copy()
|
||||
updateKey.Revoked = true
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// re-activate revoked key
|
||||
updateKey.Revoked = false
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||
assert.Error(t, err, "should not allow to update revoked key")
|
||||
|
||||
}
|
||||
|
||||
@@ -33,12 +33,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
)
|
||||
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
@@ -485,9 +486,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return nil, status.NewSetupKeyNotFoundError(setupKey)
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
|
||||
}
|
||||
|
||||
if key.AccountID == "" {
|
||||
@@ -554,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -568,15 +570,15 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting groups from store")
|
||||
log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -756,9 +758,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return "", status.NewSetupKeyNotFoundError(setupKey)
|
||||
}
|
||||
return "", status.NewSetupKeyNotFoundError(result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
|
||||
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
@@ -855,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
user.LastLogin = lastLogin
|
||||
@@ -986,9 +988,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
First(&setupKey, keyQueryCondition, key)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
return nil, status.NewSetupKeyNotFoundError(key)
|
||||
}
|
||||
return nil, status.NewSetupKeyNotFoundError(result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
||||
}
|
||||
return &setupKey, nil
|
||||
}
|
||||
@@ -1006,7 +1009,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "setup key not found")
|
||||
return status.NewSetupKeyNotFoundError(setupKeyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1042,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
@@ -1076,15 +1079,51 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
||||
var peer *nbpeer.Peer
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&peer, accountAndIDQueryCondition, accountID, peerID)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// GetPeersByIDs retrieves peers by their IDs and account ID.
|
||||
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
|
||||
}
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return peersMap, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
startTime := time.Now()
|
||||
tx := s.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
@@ -1095,12 +1134,21 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit().Error
|
||||
|
||||
err = tx.Commit().Error
|
||||
|
||||
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
return &SqlStore{
|
||||
db: tx,
|
||||
db: tx,
|
||||
storeEngine: s.storeEngine,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1152,12 +1200,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
// GetGroupByID retrieves a group by ID and account ID.
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
|
||||
var group *nbgroup.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
|
||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||
@@ -1169,25 +1227,72 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
return nil, status.NewGroupNotFoundError(groupName)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
||||
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
// SaveGroup saves a group to the store.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete group from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
|
||||
@@ -1220,12 +1325,57 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db, lockStrength, accountID)
|
||||
var setupKeys []*SetupKey
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&setupKeys, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
|
||||
}
|
||||
|
||||
return setupKeys, nil
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||
return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID)
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
||||
var setupKey *SetupKey
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
// SaveSetupKey saves a setup key to the database.
|
||||
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save setup key to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSetupKey deletes a setup key from the database.
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete setup key from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewSetupKeyNotFoundError(keyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
@@ -1238,10 +1388,6 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
||||
return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
@@ -1274,21 +1420,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// deleteRecordByID deletes a record by its ID and account ID from the database.
|
||||
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
|
||||
var record T
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "record not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,11 +14,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
|
||||
@@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
t.Fatal("failed to save group")
|
||||
return err
|
||||
}
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get group")
|
||||
return err
|
||||
@@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, len(account.Users))
|
||||
}
|
||||
@@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "All", group.Name)
|
||||
require.True(t, group.IsGroupAll())
|
||||
}
|
||||
|
||||
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
@@ -1274,7 +1273,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID)
|
||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
|
||||
@@ -1290,6 +1289,278 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
nonExistingKeyID := "non-existing-key-id"
|
||||
|
||||
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID)
|
||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing groups by existing IDs",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "non-existing group IDs",
|
||||
groupIDs: []string{"nonexistent1", "nonexistent2"},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed existing and non-existing group IDs",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group := &nbgroup.Group{
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
}
|
||||
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, savedGroup, group)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups := []*nbgroup.Group{
|
||||
{
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
},
|
||||
{
|
||||
ID: "group-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
Peers: []string{"peer3", "peer4"},
|
||||
},
|
||||
}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete existing group",
|
||||
groupID: "cfefqs706sqkneg59g4g",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existing group",
|
||||
groupID: "non-existing-group-id",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "delete with empty group ID",
|
||||
groupID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, group)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete multiple existing groups",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existing groups",
|
||||
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete with empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, groupID := range tt.groupIDs {
|
||||
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, group)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeerByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing peer",
|
||||
peerID: "cfefqs706sqkneg59g4g",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing peer",
|
||||
peerID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty peer ID",
|
||||
peerID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, peer)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, peer)
|
||||
require.Equal(t, tt.peerID, peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
peerIDs []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing peers by existing IDs",
|
||||
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty peer IDs list",
|
||||
peerIDs: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "non-existing peer IDs",
|
||||
peerIDs: []string{"nonexistent1", "nonexistent2"},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed existing and non-existing peer IDs",
|
||||
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package status
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -103,22 +102,27 @@ func NewPeerLoginExpiredError() error {
|
||||
}
|
||||
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError(err error) error {
|
||||
return Errorf(NotFound, "setup key not found: %s", err)
|
||||
func NewSetupKeyNotFoundError(setupKeyID string) error {
|
||||
return Errorf(NotFound, "setup key: %s not found", setupKeyID)
|
||||
}
|
||||
|
||||
func NewGetAccountFromStoreError(err error) error {
|
||||
return Errorf(Internal, "issue getting account from store: %s", err)
|
||||
}
|
||||
|
||||
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
|
||||
func NewUserNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "user is not part of this account")
|
||||
}
|
||||
|
||||
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||
func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
|
||||
func NewAdminPermissionError() error {
|
||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
@@ -126,7 +130,12 @@ func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
}
|
||||
|
||||
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
|
||||
func NewUnauthorizedToViewSetupKeysError() error {
|
||||
return Errorf(Unauthorized, "only users with admin power can view setup keys")
|
||||
// NewGetAccountError creates a new Error with Internal type for an issue getting account
|
||||
func NewGetAccountError(err error) error {
|
||||
return Errorf(Internal, "error getting account: %s", err)
|
||||
}
|
||||
|
||||
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
|
||||
func NewGroupNotFoundError(groupID string) error {
|
||||
return Errorf(NotFound, "group: %s not found", groupID)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ type Store interface {
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
@@ -70,11 +70,14 @@ type Store interface {
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
@@ -89,6 +92,8 @@ type Store interface {
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
@@ -96,7 +101,9 @@ type Store interface {
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
|
||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
@@ -105,7 +112,7 @@ type Store interface {
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
@@ -124,7 +131,6 @@ type Store interface {
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
|
||||
@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
@@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
}, nil
|
||||
|
||||
}
|
||||
@@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
|
||||
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
||||
|
||||
// CountPeerMetUpdate counts the number of peer meta updates
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type StoreMetrics struct {
|
||||
globalLockAcquisitionDurationMs metric.Int64Histogram
|
||||
persistenceDurationMicro metric.Int64Histogram
|
||||
persistenceDurationMs metric.Int64Histogram
|
||||
transactionDurationMs metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StoreMetrics{
|
||||
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
|
||||
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
|
||||
persistenceDurationMicro: persistenceDurationMicro,
|
||||
persistenceDurationMs: persistenceDurationMs,
|
||||
transactionDurationMs: transactionDurationMs,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
|
||||
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
// CountTransactionDuration counts the duration of a store persistence operation
|
||||
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
|
||||
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
@@ -96,9 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
|
||||
}
|
||||
|
||||
// CloseChannels closes updates channel for each given peer
|
||||
|
||||
@@ -9,14 +9,16 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -103,6 +105,11 @@ func (u *User) IsAdminOrServiceUser() bool {
|
||||
return u.HasAdminPower() || u.IsServiceUser
|
||||
}
|
||||
|
||||
// IsRegularUser checks if the user is a regular user.
|
||||
func (u *User) IsRegularUser() bool {
|
||||
return !u.HasAdminPower() && !u.IsServiceUser
|
||||
}
|
||||
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
@@ -487,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -798,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
expiredPeers = append(expiredPeers, blockedPeers...)
|
||||
}
|
||||
|
||||
peerGroupsAdded := make(map[string][]string)
|
||||
peerGroupsRemoved := make(map[string][]string)
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
}
|
||||
|
||||
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, userUpdateEvents...)
|
||||
|
||||
userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, userGroupsEvents...)
|
||||
|
||||
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
|
||||
if err != nil {
|
||||
@@ -828,7 +840,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
@@ -865,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
if newUser.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
}
|
||||
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, removedEvents...)
|
||||
|
||||
addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
|
||||
eventsToStore = append(eventsToStore, addedEvents...)
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
}
|
||||
}
|
||||
for groupID, peerIDs := range peerGroupsAdded {
|
||||
group := account.GetGroup(groupID)
|
||||
for _, peerID := range peerIDs {
|
||||
peer := account.GetPeer(peerID)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
for groupID, peerIDs := range peerGroupsRemoved {
|
||||
group := account.GetGroup(groupID)
|
||||
for _, peerID := range peerIDs {
|
||||
peer := account.GetPeer(peerID)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
@@ -1100,6 +1158,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
|
||||
var peerIDs []string
|
||||
for _, peer := range peers {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
|
||||
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
@@ -1107,8 +1168,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID)
|
||||
|
||||
am.StoreEvent(
|
||||
ctx,
|
||||
peer.UserID, peer.ID, account.Id,
|
||||
@@ -1119,7 +1183,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1227,7 +1291,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
for targetUserID, meta := range deletedUsersMeta {
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
|
||||
errCloseConn = "failed to close connection to peer: %s"
|
||||
)
|
||||
|
||||
// Peer represents a peer connection
|
||||
@@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
|
||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||
// the message accordingly.
|
||||
func (p *Peer) Work() {
|
||||
defer func() {
|
||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
||||
case messages.MsgTypeClose:
|
||||
p.log.Infof("peer exited gracefully")
|
||||
if err := p.conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
log.Errorf(errCloseConn, err)
|
||||
}
|
||||
default:
|
||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||
@@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,7 +139,7 @@ func (p *Peer) Close() {
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
68
util/file.go
68
util/file.go
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -14,8 +15,21 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare config file dir: %w", err)
|
||||
}
|
||||
|
||||
if err = EnforcePermission(file); err != nil {
|
||||
return fmt.Errorf("enforce permission: %w", err)
|
||||
}
|
||||
|
||||
return writeBytes(ctx, file, err, configDir, configFileName, bs)
|
||||
}
|
||||
|
||||
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
||||
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||
// The output JSON is pretty-formatted
|
||||
func WriteJson(file string, obj interface{}) error {
|
||||
func WriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||
@@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
|
||||
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
|
||||
// Check context before expensive operations
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("write json start: %w", ctx.Err())
|
||||
}
|
||||
|
||||
// make it pretty
|
||||
bs, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
|
||||
return writeBytes(ctx, file, err, configDir, configFileName, bs)
|
||||
}
|
||||
|
||||
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("write bytes start: %w", ctx.Err())
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create temp: %w", err)
|
||||
}
|
||||
|
||||
tempFileName := tempFile.Name()
|
||||
// closing file ops as windows doesn't allow to move it
|
||||
err = tempFile.Close()
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
|
||||
log.Warnf("failed to set deadline: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tempFile.Write(bs)
|
||||
if err != nil {
|
||||
return err
|
||||
_ = tempFile.Close()
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
if err = tempFile.Close(); err != nil {
|
||||
return fmt.Errorf("close %s: %w", tempFileName, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
|
||||
}
|
||||
}()
|
||||
|
||||
err = os.WriteFile(tempFileName, bs, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
// Check context again
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("after temp file: %w", ctx.Err())
|
||||
}
|
||||
|
||||
err = os.Rename(tempFileName, file)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = os.Rename(tempFileName, file); err != nil {
|
||||
return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
@@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
|
||||
err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
||||
@@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
|
||||
src := tmpDir + "/copytest_src"
|
||||
dst := tmpDir + "/copytest_dst"
|
||||
|
||||
err := WriteJson(src, tt.srcContent)
|
||||
err := WriteJson(context.Background(), src, tt.srcContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = CopyFileContents(src, dst)
|
||||
@@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
|
||||
_ = os.Remove(cfgFile)
|
||||
}()
|
||||
|
||||
err := WriteJson(cfgFile, tt.config)
|
||||
err := WriteJson(context.Background(), cfgFile, tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(cfgFile, &TestConfig{})
|
||||
|
||||
@@ -3,6 +3,9 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
@@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get current user: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Using nbnet.NewDialer()")
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
|
||||
@@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial: %w", err)
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
|
||||
@@ -4,9 +4,14 @@ package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
|
||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||
func SetSocketMark(conn syscall.Conn) error {
|
||||
sysconn, err := conn.SyscallConn()
|
||||
@@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
|
||||
func SetSocketOpt(fd int) error {
|
||||
if CustomRoutingDisabled() {
|
||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for the new environment variable
|
||||
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
|
||||
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user