Compare commits

...

18 Commits

Author SHA1 Message Date
crn4
3945d2b170 bufferUpdate mu unlock after op finish 2025-07-09 17:53:37 +02:00
crn4
0957defa54 some logs over buffer update 2025-07-09 15:34:33 +02:00
crn4
0c6ab1de30 added cleanupWindow for collecting several ephemeral peers to delete and run BufferUpdateAccountPeers once 2025-07-09 00:15:45 +02:00
Pedro Costa
2e18d77d40 further optimization to ensure db roundtrip and calculations only occur if there are peers to update 2025-07-08 09:56:15 +01:00
Pedro Costa
3a8a6fcb76 further test fixes 2025-07-08 09:36:16 +01:00
Pedro Costa
49f083a372 fix dns and client tests 2025-07-08 09:07:26 +01:00
Pedro Costa
7d9ca73f6c fix ephemeral test 2025-07-08 08:58:31 +01:00
Pedro Costa
470b80c1b8 get extra settings only once per updateaccountpeers 2025-07-08 08:42:45 +01:00
Maycon Santos
ad0b78a7ac add request id to ephemeral cleanup 2025-07-08 02:04:05 +02:00
Maycon Santos
7aa2ca87f2 split call to BufferUpdateAccountPeers when system user initiates 2025-07-08 01:41:07 +02:00
Maycon Santos
4c58088311 fix go mod 2025-07-07 19:38:02 +02:00
Maycon Santos
bcccd65008 add rate limit 2025-07-07 19:35:48 +02:00
Maycon Santos
1ffc8933de add rate limit 2025-07-07 19:15:54 +02:00
Maycon Santos
ad22e9eea1 Merge branch 'main' into add-account-onboarding 2025-07-02 02:59:03 +02:00
Maycon Santos
d806fc4a03 handle empty onboard to avoid breaking clients and dashboard 2025-07-02 01:38:40 +02:00
Maycon Santos
7a5edb3894 create accounts with pending onboarding 2025-07-01 23:25:52 +02:00
Maycon Santos
2dc230ab9a add store and account manager methods
add store tests
2025-07-01 19:54:53 +02:00
Maycon Santos
432dc42bf5 add account onboarding 2025-07-01 11:51:46 +02:00
25 changed files with 753 additions and 147 deletions

View File

@@ -102,6 +102,11 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {

2
go.mod
View File

@@ -105,6 +105,7 @@ require (
golang.org/x/oauth2 v0.24.0 golang.org/x/oauth2 v0.24.0
golang.org/x/sync v0.13.0 golang.org/x/sync v0.13.0
golang.org/x/term v0.31.0 golang.org/x/term v0.31.0
golang.org/x/time v0.5.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
@@ -240,7 +241,6 @@ require (
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.24.0 // indirect golang.org/x/text v0.24.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect

View File

@@ -87,6 +87,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
). ).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock. permissionsManagerMock.
EXPECT(). EXPECT().

View File

@@ -1192,6 +1192,71 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err)
return nil, err
}
if onboarding == nil {
onboarding = &types.AccountOnboarding{
AccountID: accountID,
}
}
return onboarding, nil
}
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
}
if oldOnboarding == nil {
oldOnboarding = &types.AccountOnboarding{
AccountID: accountID,
}
}
if newOnboarding == nil {
return oldOnboarding, nil
}
if oldOnboarding.IsEqual(*newOnboarding) {
log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID)
return oldOnboarding, nil
}
newOnboarding.AccountID = accountID
err = am.Store.SaveAccountOnboarding(ctx, newOnboarding)
if err != nil {
return nil, fmt.Errorf("failed to update account onboarding: %w", err)
}
return newOnboarding, nil
}
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
if userAuth.UserId == "" { if userAuth.UserId == "" {
return "", "", errors.New(emptyUserID) return "", "", errors.New(emptyUserID)
@@ -1733,6 +1798,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true, RoutingPeerDNSResolutionEnabled: true,
}, },
Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
},
} }
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {

View File

@@ -39,6 +39,7 @@ type Manager interface {
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, accountID string) (bool, error) AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
@@ -89,6 +90,7 @@ type Manager interface {
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)
@@ -110,6 +112,7 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string) UpdateAccountPeers(ctx context.Context, accountID string)
BufferUpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store GetStore() store.Store

View File

@@ -3440,3 +3440,74 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
} }
}) })
} }
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
require.NoError(t, err)
require.NotNil(t, onboarding)
assert.Equal(t, account.Id, onboarding.AccountID)
assert.Equal(t, true, onboarding.OnboardingFlowPending)
assert.Equal(t, true, onboarding.SignupFormPending)
if onboarding.UpdatedAt.IsZero() {
t.Errorf("Onboarding was not retrieved from the store")
}
})
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
account.Id = "with-zero-onboarding"
account.Onboarding = types.AccountOnboarding{}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
require.NoError(t, err)
require.NotNil(t, onboarding)
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
require.Error(t, err, "should return error when onboarding is not set")
})
}
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
onboarding := &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}
t.Run("update onboarding with no change", func(t *testing.T) {
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
require.NoError(t, err)
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
if updated.UpdatedAt.IsZero() {
t.Errorf("Onboarding was updated in the store")
}
})
onboarding.OnboardingFlowPending = false
onboarding.SignupFormPending = false
t.Run("update onboarding", func(t *testing.T) {
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
require.NoError(t, err)
require.NotNil(t, updated)
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
})
t.Run("update onboarding with no onboarding", func(t *testing.T) {
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
require.NoError(t, err)
})
}

View File

@@ -216,6 +216,8 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }

View File

@@ -5,16 +5,20 @@ import (
"sync" "sync"
"time" "time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbAccount "github.com/netbirdio/netbird/management/server/account" nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
) )
const ( const (
ephemeralLifeTime = 10 * time.Minute ephemeralLifeTime = 10 * time.Minute
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
cleanupWindow = 1 * time.Minute
) )
var ( var (
@@ -41,6 +45,9 @@ type EphemeralManager struct {
tailPeer *ephemeralPeer tailPeer *ephemeralPeer
peersLock sync.Mutex peersLock sync.Mutex
timer *time.Timer timer *time.Timer
lifeTime time.Duration
cleanupWindow time.Duration
} }
// NewEphemeralManager instantiate new EphemeralManager // NewEphemeralManager instantiate new EphemeralManager
@@ -48,6 +55,9 @@ func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *E
return &EphemeralManager{ return &EphemeralManager{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
lifeTime: ephemeralLifeTime,
cleanupWindow: cleanupWindow,
} }
} }
@@ -60,7 +70,7 @@ func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.loadEphemeralPeers(ctx) e.loadEphemeralPeers(ctx)
if e.headPeer != nil { if e.headPeer != nil {
e.timer = time.AfterFunc(ephemeralLifeTime, func() { e.timer = time.AfterFunc(e.lifeTime, func() {
e.cleanup(ctx) e.cleanup(ctx)
}) })
} }
@@ -113,9 +123,13 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
return return
} }
e.addPeer(peer.AccountID, peer.ID, newDeadLine()) e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
if e.timer == nil { if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx) e.cleanup(ctx)
}) })
} }
@@ -128,7 +142,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
return return
} }
t := newDeadLine() t := e.newDeadLine()
for _, p := range peers { for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t) e.addPeer(p.AccountID, p.ID, t)
} }
@@ -138,6 +152,9 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
func (e *EphemeralManager) cleanup(ctx context.Context) { func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup") log.Tracef("on ephemeral cleanup")
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
deletePeers := make(map[string]*ephemeralPeer) deletePeers := make(map[string]*ephemeralPeer)
e.peersLock.Lock() e.peersLock.Lock()
@@ -155,7 +172,11 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
} }
if e.headPeer != nil { if e.headPeer != nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx) e.cleanup(ctx)
}) })
} else { } else {
@@ -164,13 +185,21 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock() e.peersLock.Unlock()
bufferAccountCall := make(map[string]struct{})
for id, p := range deletePeers { for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} else {
bufferAccountCall[p.accountID] = struct{}{}
} }
} }
for accountID := range bufferAccountCall {
log.WithContext(ctx).Debugf("ephemeral - buffer update account peers for account: %s", accountID)
e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
}
} }
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
@@ -223,6 +252,6 @@ func (e *EphemeralManager) isPeerOnList(id string) bool {
return false return false
} }
func newDeadLine() time.Time { func (e *EphemeralManager) newDeadLine() time.Time {
return timeNow().Add(ephemeralLifeTime) return timeNow().Add(e.lifeTime)
} }

View File

@@ -3,9 +3,12 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
nbAccount "github.com/netbirdio/netbird/management/server/account" nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -27,28 +30,65 @@ func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStren
return peers, nil return peers, nil
} }
type MocAccountManager struct { type MockAccountManager struct {
mu sync.Mutex
nbAccount.Manager nbAccount.Manager
store *MockStore store *MockStore
deletePeerCalls int
bufferUpdateCalls map[string]int
wg *sync.WaitGroup
} }
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
a.mu.Lock()
defer a.mu.Unlock()
a.deletePeerCalls++
if a.wg != nil {
a.wg.Done()
}
delete(a.store.account.Peers, peerID) delete(a.store.account.Peers, peerID)
return nil //nolint:nil return nil
} }
func (a MocAccountManager) GetStore() store.Store { func (a *MockAccountManager) GetDeletePeerCalls() int {
a.mu.Lock()
defer a.mu.Unlock()
return a.deletePeerCalls
}
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
a.bufferUpdateCalls = make(map[string]int)
}
a.bufferUpdateCalls[accountID]++
}
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
return 0
}
return a.bufferUpdateCalls[accountID]
}
func (a *MockAccountManager) GetStore() store.Store {
return a.store return a.store
} }
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now() startTime := time.Now()
timeNow = func() time.Time { timeNow = func() time.Time {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{}
am := MocAccountManager{ am := MockAccountManager{
store: store, store: store,
} }
@@ -56,7 +96,7 @@ func TestNewManager(t *testing.T) {
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, &am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
@@ -67,13 +107,16 @@ func TestNewManager(t *testing.T) {
} }
func TestNewManagerPeerConnected(t *testing.T) { func TestNewManagerPeerConnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now() startTime := time.Now()
timeNow = func() time.Time { timeNow = func() time.Time {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{}
am := MocAccountManager{ am := MockAccountManager{
store: store, store: store,
} }
@@ -81,7 +124,7 @@ func TestNewManagerPeerConnected(t *testing.T) {
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, &am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
@@ -95,13 +138,16 @@ func TestNewManagerPeerConnected(t *testing.T) {
} }
func TestNewManagerPeerDisconnected(t *testing.T) { func TestNewManagerPeerDisconnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now() startTime := time.Now()
timeNow = func() time.Time { timeNow = func() time.Time {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{}
am := MocAccountManager{ am := MockAccountManager{
store: store, store: store,
} }
@@ -109,7 +155,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, &am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers { for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v) mgr.OnPeerConnected(context.Background(), v)
@@ -126,6 +172,36 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
} }
} }
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
const (
ephemeralPeers = 10
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
mockStore := &MockStore{}
mockAM := &MockAccountManager{
store: mockStore,
}
mockAM.wg = &sync.WaitGroup{}
mockAM.wg.Add(ephemeralPeers)
mgr := NewEphemeralManager(mockStore, mockAM)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
time.Sleep(testCleanupWindow / ephemeralPeers)
mgr.OnPeerDisconnected(context.Background(), p)
}
mockAM.wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted only the first peer")
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "", false) store.account = newAccountWithId(context.Background(), "my account", "", "", false)

View File

@@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -13,6 +15,7 @@ import (
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
@@ -47,6 +50,10 @@ type GRPCServer struct {
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
peerLocks sync.Map peerLocks sync.Map
authManager auth.Manager authManager auth.Manager
syncLimiter *rate.Limiter
loginLimiterStore sync.Map
loginPeerBooster int
loginPeerLimit rate.Limit
} }
// NewServer creates a new Management server // NewServer creates a new Management server
@@ -76,6 +83,41 @@ func NewServer(
} }
} }
multiplier := time.Minute
d, e := time.ParseDuration(os.Getenv("NB_LOGIN_RATE"))
if e == nil {
multiplier = d
}
loginRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_RATE_PER_M"))
if loginRatePerS == 0 || err != nil {
loginRatePerS = 200
}
loginBurst, err := strconv.Atoi(os.Getenv("NB_LOGIN_BURST"))
if loginBurst == 0 || err != nil {
loginBurst = 200
}
log.WithContext(ctx).Infof("login burst limit set to %d", loginBurst)
loginPeerRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_PEER_RATE_PER_M"))
if loginPeerRatePerS == 0 || err != nil {
loginPeerRatePerS = 200
}
log.WithContext(ctx).Infof("login rate limit set to %d/min", loginRatePerS)
syncRatePerS, err := strconv.Atoi(os.Getenv("NB_SYNC_RATE_PER_M"))
if syncRatePerS == 0 || err != nil {
syncRatePerS = 20000
}
log.WithContext(ctx).Infof("sync rate limit set to %d/min", syncRatePerS)
syncBurst, err := strconv.Atoi(os.Getenv("NB_SYNC_BURST"))
if syncBurst == 0 || err != nil {
syncBurst = 30000
}
log.WithContext(ctx).Infof("sync burst limit set to %d", syncBurst)
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
@@ -87,6 +129,9 @@ func NewServer(
authManager: authManager, authManager: authManager,
appMetrics: appMetrics, appMetrics: appMetrics,
ephemeralManager: ephemeralManager, ephemeralManager: ephemeralManager,
syncLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(syncRatePerS)), syncBurst),
loginPeerLimit: rate.Every(multiplier / time.Duration(loginPeerRatePerS)),
loginPeerBooster: loginBurst,
}, nil }, nil
} }
@@ -128,11 +173,17 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account) // notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
reqStart := time.Now()
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest() s.appMetrics.GRPCMetrics().CountSyncRequest()
} }
if !s.syncLimiter.Allow() {
log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey)
return status.Errorf(codes.Internal, "temp rate limit reached")
}
reqStart := time.Now()
ctx := srv.Context() ctx := srv.Context()
syncReq := &proto.SyncRequest{} syncReq := &proto.SyncRequest{}
@@ -428,15 +479,39 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful // In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
limiterIface, ok := s.loginLimiterStore.Load(req.WgPubKey)
if !ok {
// Create new limiter for this peer
newLimiter := rate.NewLimiter(s.loginPeerLimit, s.loginPeerBooster)
s.loginLimiterStore.Store(req.WgPubKey, newLimiter)
if !newLimiter.Allow() {
//time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
return nil, fmt.Errorf("temp rate limit reached (new peer limit)")
}
} else {
// Use existing limiter for this peer
limiter := limiterIface.(*rate.Limiter)
if !limiter.Allow() {
//time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
return nil, fmt.Errorf("temp rate limit reached (peer limit)")
}
}
reqStart := time.Now() reqStart := time.Now()
defer func() { defer func() {
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
} }
}() }()
if s.appMetrics != nil { //if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest() // s.appMetrics.GRPCMetrics().CountLoginRequest()
} //}
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())

View File

@@ -60,6 +60,8 @@ components:
description: Account creator description: Account creator
type: string type: string
example: google-oauth2|277474792786460067937 example: google-oauth2|277474792786460067937
onboarding:
$ref: '#/components/schemas/AccountOnboarding'
required: required:
- id - id
- settings - settings
@@ -67,6 +69,21 @@ components:
- domain_category - domain_category
- created_at - created_at
- created_by - created_by
- onboarding
AccountOnboarding:
type: object
properties:
signup_form_pending:
description: Indicates whether the account signup form is pending
type: boolean
example: true
onboarding_flow_pending:
description: Indicates whether the account onboarding flow is pending
type: boolean
example: false
required:
- signup_form_pending
- onboarding_flow_pending
AccountSettings: AccountSettings:
type: object type: object
properties: properties:
@@ -153,6 +170,8 @@ components:
properties: properties:
settings: settings:
$ref: '#/components/schemas/AccountSettings' $ref: '#/components/schemas/AccountSettings'
onboarding:
$ref: '#/components/schemas/AccountOnboarding'
required: required:
- settings - settings
User: User:

View File

@@ -251,6 +251,7 @@ type Account struct {
// Id Account ID // Id Account ID
Id string `json:"id"` Id string `json:"id"`
Onboarding AccountOnboarding `json:"onboarding"`
Settings AccountSettings `json:"settings"` Settings AccountSettings `json:"settings"`
} }
@@ -266,8 +267,18 @@ type AccountExtraSettings struct {
PeerApprovalEnabled bool `json:"peer_approval_enabled"` PeerApprovalEnabled bool `json:"peer_approval_enabled"`
} }
// AccountOnboarding defines model for AccountOnboarding.
type AccountOnboarding struct {
// OnboardingFlowPending Indicates whether the account onboarding flow is pending
OnboardingFlowPending bool `json:"onboarding_flow_pending"`
// SignupFormPending Indicates whether the account signup form is pending
SignupFormPending bool `json:"signup_form_pending"`
}
// AccountRequest defines model for AccountRequest. // AccountRequest defines model for AccountRequest.
type AccountRequest struct { type AccountRequest struct {
Onboarding *AccountOnboarding `json:"onboarding,omitempty"`
Settings AccountSettings `json:"settings"` Settings AccountSettings `json:"settings"`
} }

View File

@@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
return return
} }
resp := toAccountResponse(accountID, settings, meta) onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
@@ -126,6 +132,20 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
} }
var onboarding *types.AccountOnboarding
if req.Onboarding != nil {
onboarding = &types.AccountOnboarding{
OnboardingFlowPending: req.Onboarding.OnboardingFlowPending,
SignupFormPending: req.Onboarding.SignupFormPending,
}
}
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -138,7 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
resp := toAccountResponse(accountID, updatedSettings, meta) resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
@@ -167,7 +187,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account { func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil { if jwtAllowGroups == nil {
jwtAllowGroups = []string{} jwtAllowGroups = []string{}
@@ -188,6 +208,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
DnsDomain: &settings.DNSDomain, DnsDomain: &settings.DNSDomain,
} }
apiOnboarding := api.AccountOnboarding{
OnboardingFlowPending: onboarding.OnboardingFlowPending,
SignupFormPending: onboarding.SignupFormPending,
}
if settings.Extra != nil { if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{ apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
@@ -203,5 +228,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
CreatedBy: meta.CreatedBy, CreatedBy: meta.CreatedBy,
Domain: meta.Domain, Domain: meta.Domain,
DomainCategory: meta.DomainCategory, DomainCategory: meta.DomainCategory,
Onboarding: apiOnboarding,
} }
} }

View File

@@ -54,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
return account.GetMeta(), nil return account.GetMeta(), nil
}, },
GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
return &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}, nil
},
UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
return &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}, nil
},
}, },
settingsManager: settingsMockManager, settingsManager: settingsMockManager,
} }
@@ -117,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true, expectedBody: true,
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID, requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{ expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000, PeerLoginExpiration: 15552000,
@@ -139,7 +151,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true, expectedBody: true,
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID, requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{ expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000, PeerLoginExpiration: 15552000,
@@ -161,7 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true, expectedBody: true,
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID, requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{ expectedSettings: api.AccountSettings{
PeerLoginExpiration: 554400, PeerLoginExpiration: 554400,
@@ -178,12 +190,34 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,
}, },
{
name: "PutAccount OK without onboarding",
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: false,
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr("roles"),
JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{"test"},
RegularUsersViewBlocked: true,
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
},
expectedArray: false,
expectedID: accountID,
},
{ {
name: "Update account failure with high peer_login_expiration more than 180 days", name: "Update account failure with high peer_login_expiration more than 180 days",
expectedBody: true, expectedBody: true,
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID, requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"), requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedArray: false, expectedArray: false,
}, },
@@ -192,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true, expectedBody: true,
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID, requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"), requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedArray: false, expectedArray: false,
}, },

View File

@@ -440,7 +440,11 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
AnyTimes(). AnyTimes().
Return(&types.Settings{}, nil) Return(&types.Settings{}, nil)
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",

View File

@@ -117,7 +117,8 @@ type MockAccountManager struct {
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
} }
@@ -125,6 +126,10 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
// do nothing // do nothing
} }
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
// do nothing
}
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
if am.DeleteSetupKeyFunc != nil { if am.DeleteSetupKeyFunc != nil {
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
@@ -814,6 +819,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri
return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented")
} }
// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface
func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
if am.GetAccountOnboardingFunc != nil {
return am.GetAccountOnboardingFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented")
}
// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface
func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
if am.UpdateAccountOnboardingFunc != nil {
return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented")
}
// GetUserByID mocks GetUserByID of the AccountManager interface // GetUserByID mocks GetUserByID of the AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
if am.GetUserByIDFunc != nil { if am.GetUserByIDFunc != nil {

View File

@@ -778,6 +778,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }

View File

@@ -391,7 +391,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent() storeEvent()
} }
if updateAccountPeers { if updateAccountPeers && userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID) am.BufferUpdateAccountPeers(ctx, accountID)
} }
@@ -1169,7 +1169,26 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
if am.metrics != nil {
globalStart := time.Now() globalStart := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
}()
}
peersToUpdate := []*nbpeer.Peer{}
for _, peer := range account.Peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
peersToUpdate = append(peersToUpdate, peer)
}
if len(peersToUpdate) == 0 {
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil { if err != nil {
@@ -1192,12 +1211,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
for _, peer := range account.Peers { extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if !am.peersUpdateManager.HasChannel(peer.ID) { if err != nil {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
continue return
} }
for _, peer := range peersToUpdate {
wg.Add(1) wg.Add(1)
semaphore <- struct{}{} semaphore <- struct{}{}
go func(p *nbpeer.Peer) { go func(p *nbpeer.Peer) {
@@ -1226,12 +1246,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
} }
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
return
}
start = time.Now() start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
@@ -1240,26 +1254,27 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
}(peer) }(peer)
} }
//
wg.Wait() wg.Wait()
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
}
} }
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex) lock := mu.(*sync.Mutex)
log.WithContext(ctx).Debugf("try to BufferUpdateAccountPeers for account %s", accountID)
if !lock.TryLock() { if !lock.TryLock() {
log.WithContext(ctx).Debugf("BufferUpdateAccountPeers for an account %s locked - returning", accountID)
return return
} }
go func() { go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) // time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock() defer lock.Unlock()
log.WithContext(ctx).Debugf("BufferUpdateAccountPeers for an account %s - in progress", accountID)
tn := time.Now()
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
log.WithContext(ctx).Debugf("BufferUpdateAccountPeers for an account %s - took %dms", accountID, time.Since(tn).Milliseconds())
}() }()
} }

View File

@@ -1340,6 +1340,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
@@ -1544,6 +1549,11 @@ func Test_LoginPeer(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)

View File

@@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey) return Errorf(NotFound, "account not found: %s", accountKey)
} }
// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding
func NewAccountOnboardingNotFoundError(accountKey string) error {
return Errorf(NotFound, "account onboarding not found: %s", accountKey)
}
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account // NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
func NewPeerNotPartOfAccountError() error { func NewPeerNotPartOfAccountError() error {
return Errorf(PermissionDenied, "peer is not part of this account") return Errorf(PermissionDenied, "peer is not part of this account")

View File

@@ -99,7 +99,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err) return nil, fmt.Errorf("auto migrate: %w", err)
@@ -725,6 +725,32 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren
return &accountMeta, nil return &accountMeta, nil
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) {
var accountOnboarding types.AccountOnboarding
result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountOnboardingNotFoundError(accountID)
}
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
return &accountOnboarding, nil
}
// SaveAccountOnboarding updates the onboarding information for a specific account.
func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error {
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
}
return nil
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {

View File

@@ -353,9 +353,16 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
o, err := store.GetAccountOnboarding(context.Background(), account.Id)
require.NoError(t, err)
require.Equal(t, o.AccountID, account.Id)
err = store.DeleteAccount(context.Background(), account) err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
_, err = store.GetAccountOnboarding(context.Background(), account.Id)
require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding")
if len(store.GetAllAccounts(context.Background())) != 0 { if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
} }
@@ -413,12 +420,21 @@ func Test_GetAccount(t *testing.T) {
account, err := store.GetAccount(context.Background(), id) account, err := store.GetAccount(context.Background(), id)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match") require.Equal(t, id, account.Id, "account id should match")
require.Equal(t, false, account.Onboarding.OnboardingFlowPending)
id = "9439-34653001fc3b-bf1c8084-ba50-4ce7"
account, err = store.GetAccount(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match")
require.Equal(t, true, account.Onboarding.OnboardingFlowPending)
_, err = store.GetAccount(context.Background(), "non-existing-account") _, err = store.GetAccount(context.Background(), "non-existing-account")
assert.Error(t, err) assert.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}) })
} }
@@ -2042,6 +2058,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
PeerInactivityExpirationEnabled: false, PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
}, },
Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true},
} }
if err := acc.AddAllGroup(false); err != nil { if err := acc.AddAllGroup(false); err != nil {
@@ -3386,6 +3403,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) {
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
} }
func TestSqlStore_GetAccountOnboarding(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
a, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
t.Logf("Onboarding: %+v", a.Onboarding)
err = store.SaveAccount(context.Background(), a)
require.NoError(t, err)
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
require.NotNil(t, onboarding)
require.Equal(t, accountID, onboarding.AccountID)
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC())
}
func TestSqlStore_SaveAccountOnboarding(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
t.Run("New onboarding should be saved correctly", func(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
onboarding := &types.AccountOnboarding{
AccountID: accountID,
SignupFormPending: true,
OnboardingFlowPending: true,
}
err = store.SaveAccountOnboarding(context.Background(), onboarding)
require.NoError(t, err)
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
})
t.Run("Existing onboarding should be updated correctly", func(t *testing.T) {
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending
onboarding.SignupFormPending = !onboarding.SignupFormPending
err = store.SaveAccountOnboarding(context.Background(), onboarding)
require.NoError(t, err)
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
})
}
func TestSqlStore_GetAnyAccountID(t *testing.T) { func TestSqlStore_GetAnyAccountID(t *testing.T) {
t.Run("should return account ID when accounts exist", func(t *testing.T) { t.Run("should return account ID when accounts exist", func(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@@ -52,6 +52,7 @@ type Store interface {
GetAllAccounts(ctx context.Context) []*types.Account GetAllAccounts(ctx context.Context) []*types.Account
GetAccount(ctx context.Context, accountID string) (*types.Account, error) GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
@@ -74,6 +75,7 @@ type Store interface {
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)

View File

@@ -1,4 +1,5 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
@@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`);
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');

View File

@@ -83,10 +83,10 @@ type Account struct {
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings // Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
} }
// Subclass used in gorm to only load network and not whole account // Subclass used in gorm to only load network and not whole account
@@ -104,6 +104,20 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
} }
type AccountOnboarding struct {
AccountID string `gorm:"primaryKey"`
OnboardingFlowPending bool
SignupFormPending bool
CreatedAt time.Time
UpdatedAt time.Time
}
// IsEqual compares two AccountOnboarding objects and returns true if they are equal
func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
return o.OnboardingFlowPending == onboarding.OnboardingFlowPending &&
o.SignupFormPending == onboarding.SignupFormPending
}
// GetRoutesToSync returns the enabled routes for the peer ID and the routes // GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID. // from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
@@ -866,6 +880,7 @@ func (a *Account) Copy() *Account {
Networks: nets, Networks: nets,
NetworkRouters: networkRouters, NetworkRouters: networkRouters,
NetworkResources: networkResources, NetworkResources: networkResources,
Onboarding: a.Onboarding,
} }
} }