mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
30 Commits
feature/ne
...
add-accoun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3945d2b170 | ||
|
|
0957defa54 | ||
|
|
0c6ab1de30 | ||
|
|
2e18d77d40 | ||
|
|
3a8a6fcb76 | ||
|
|
49f083a372 | ||
|
|
7d9ca73f6c | ||
|
|
470b80c1b8 | ||
|
|
ad0b78a7ac | ||
|
|
7aa2ca87f2 | ||
|
|
4c58088311 | ||
|
|
bcccd65008 | ||
|
|
1ffc8933de | ||
|
|
ad22e9eea1 | ||
|
|
d9402168ad | ||
|
|
dbdef04b9e | ||
|
|
d806fc4a03 | ||
|
|
7a5edb3894 | ||
|
|
2dc230ab9a | ||
|
|
29cbfe8467 | ||
|
|
6ce8643368 | ||
|
|
432dc42bf5 | ||
|
|
07d1ad35fc | ||
|
|
ef6cd36f1a | ||
|
|
c1c71b6d39 | ||
|
|
0480507a10 | ||
|
|
34ac4e4b5a | ||
|
|
52ff9d9602 | ||
|
|
1b73fae46e | ||
|
|
d897365abc |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -134,6 +134,7 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -180,6 +181,7 @@ jobs:
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -102,8 +102,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1476,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,10 @@ func (n Nexthop) String() string {
|
||||
if n.Intf == nil {
|
||||
return n.IP.String()
|
||||
}
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
if n.IP.IsValid() {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
<Wix
|
||||
xmlns="http://wixtoolset.org/schemas/v4/wxs">
|
||||
xmlns="http://wixtoolset.org/schemas/v4/wxs"
|
||||
xmlns:util="http://wixtoolset.org/schemas/v4/wxs/util">
|
||||
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="NetBird GmbH" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
|
||||
InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
|
||||
|
||||
|
||||
<MediaTemplate EmbedCab="yes" />
|
||||
|
||||
<Feature Id="NetbirdFeature" Title="Netbird" Level="1">
|
||||
@@ -46,29 +48,10 @@
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
</ComponentGroup>
|
||||
|
||||
<Property Id="cmd" Value="cmd.exe"/>
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
|
||||
|
||||
<CustomAction Id="KillDaemon"
|
||||
ExeCommand='/c "taskkill /im netbird.exe"'
|
||||
Execute="deferred"
|
||||
Property="cmd"
|
||||
Impersonate="no"
|
||||
Return="ignore"
|
||||
/>
|
||||
|
||||
<CustomAction Id="KillUI"
|
||||
ExeCommand='/c "taskkill /im netbird-ui.exe"'
|
||||
Execute="deferred"
|
||||
Property="cmd"
|
||||
Impersonate="no"
|
||||
Return="ignore"
|
||||
/>
|
||||
|
||||
<InstallExecuteSequence>
|
||||
<!-- For Uninstallation -->
|
||||
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
|
||||
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
|
||||
</InstallExecuteSequence>
|
||||
|
||||
<!-- Icons -->
|
||||
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
|
||||
|
||||
@@ -206,7 +206,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
|
||||
func (s *serviceClient) onSessionExpire() {
|
||||
s.sendNotification = true
|
||||
if s.sendNotification {
|
||||
s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
go s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
s.sendNotification = false
|
||||
}
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -105,6 +105,7 @@ require (
|
||||
golang.org/x/oauth2 v0.24.0
|
||||
golang.org/x/sync v0.13.0
|
||||
golang.org/x/term v0.31.0
|
||||
golang.org/x/time v0.5.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@@ -240,7 +241,6 @@ require (
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
|
||||
@@ -15,6 +15,7 @@ NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAI
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
@@ -60,7 +61,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
||||
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-0}
|
||||
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
|
||||
# Dashboard
|
||||
@@ -139,3 +140,4 @@ export NETBIRD_RELAY_PORT
|
||||
export NETBIRD_RELAY_ENDPOINT
|
||||
export NETBIRD_RELAY_AUTH_SECRET
|
||||
export NETBIRD_RELAY_TAG
|
||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -791,7 +791,6 @@ services:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
- '8080:8080'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
"0.0.0.0/0"
|
||||
]
|
||||
},
|
||||
"DisableDefaultPolicy": $NETBIRD_MGMT_DISABLE_DEFAULT_POLICY,
|
||||
"Datadir": "",
|
||||
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
|
||||
"StoreConfig": {
|
||||
|
||||
@@ -92,7 +92,8 @@ NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
|
||||
# Disable default all-to-all policy for new accounts
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||
# -------------------------------------------
|
||||
# Relay settings
|
||||
# -------------------------------------------
|
||||
|
||||
@@ -29,3 +29,4 @@ NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
||||
NETBIRD_RELAY_PORT=33445
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -87,6 +87,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.
|
||||
EXPECT().
|
||||
@@ -100,7 +106,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ var (
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
proxyController := integrations.NewController(store)
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -102,6 +102,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
@@ -170,6 +172,7 @@ func BuildManager(
|
||||
proxyController port_forwarding.Controller,
|
||||
settingsManager settings.Manager,
|
||||
permissionsManager permissions.Manager,
|
||||
disableDefaultPolicy bool,
|
||||
) (*DefaultAccountManager, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -195,6 +198,7 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -543,7 +547,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain)
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
@@ -1188,6 +1192,71 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
|
||||
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if onboarding == nil {
|
||||
onboarding = &types.AccountOnboarding{
|
||||
AccountID: accountID,
|
||||
}
|
||||
}
|
||||
|
||||
return onboarding, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
|
||||
}
|
||||
|
||||
if oldOnboarding == nil {
|
||||
oldOnboarding = &types.AccountOnboarding{
|
||||
AccountID: accountID,
|
||||
}
|
||||
}
|
||||
|
||||
if newOnboarding == nil {
|
||||
return oldOnboarding, nil
|
||||
}
|
||||
|
||||
if oldOnboarding.IsEqual(*newOnboarding) {
|
||||
log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID)
|
||||
return oldOnboarding, nil
|
||||
}
|
||||
|
||||
newOnboarding.AccountID = accountID
|
||||
err = am.Store.SaveAccountOnboarding(ctx, newOnboarding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update account onboarding: %w", err)
|
||||
}
|
||||
|
||||
return newOnboarding, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", "", errors.New(emptyUserID)
|
||||
@@ -1688,7 +1757,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
@@ -1729,9 +1798,13 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
RoutingPeerDNSResolutionEnabled: true,
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -1833,7 +1906,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
},
|
||||
}
|
||||
|
||||
if err := newAccount.AddAllGroup(); err != nil {
|
||||
if err := newAccount.AddAllGroup(am.disableDefaultPolicy); err != nil {
|
||||
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ type Manager interface {
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
@@ -89,6 +90,7 @@ type Manager interface {
|
||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
@@ -110,6 +112,7 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
GetStore() store.Store
|
||||
|
||||
@@ -373,7 +373,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -398,7 +398,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -640,7 +640,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
@@ -793,7 +793,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2879,7 +2879,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3440,3 +3440,74 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
assert.Equal(t, account.Id, onboarding.AccountID)
|
||||
assert.Equal(t, true, onboarding.OnboardingFlowPending)
|
||||
assert.Equal(t, true, onboarding.SignupFormPending)
|
||||
if onboarding.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was not retrieved from the store")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
|
||||
account.Id = "with-zero-onboarding"
|
||||
account.Onboarding = types.AccountOnboarding{}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
|
||||
require.Error(t, err, "should return error when onboarding is not set")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
onboarding := &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}
|
||||
|
||||
t.Run("update onboarding with no change", func(t *testing.T) {
|
||||
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
if updated.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was updated in the store")
|
||||
}
|
||||
})
|
||||
|
||||
onboarding.OnboardingFlowPending = false
|
||||
onboarding.SignupFormPending = false
|
||||
|
||||
t.Run("update onboarding", func(t *testing.T) {
|
||||
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
})
|
||||
|
||||
t.Run("update onboarding with no onboarding", func(t *testing.T) {
|
||||
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -216,8 +216,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -267,7 +269,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
|
||||
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
|
||||
@@ -5,16 +5,20 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
const (
|
||||
ephemeralLifeTime = 10 * time.Minute
|
||||
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
|
||||
cleanupWindow = 1 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,6 +45,9 @@ type EphemeralManager struct {
|
||||
tailPeer *ephemeralPeer
|
||||
peersLock sync.Mutex
|
||||
timer *time.Timer
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
@@ -48,6 +55,9 @@ func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *E
|
||||
return &EphemeralManager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
|
||||
lifeTime: ephemeralLifeTime,
|
||||
cleanupWindow: cleanupWindow,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +70,7 @@ func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
|
||||
|
||||
e.loadEphemeralPeers(ctx)
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
|
||||
e.timer = time.AfterFunc(e.lifeTime, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
@@ -113,9 +123,13 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
return
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
@@ -128,7 +142,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
t := newDeadLine()
|
||||
t := e.newDeadLine()
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
@@ -138,6 +152,9 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
|
||||
func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
log.Tracef("on ephemeral cleanup")
|
||||
reqID := uuid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
deletePeers := make(map[string]*ephemeralPeer)
|
||||
|
||||
e.peersLock.Lock()
|
||||
@@ -155,7 +172,11 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
}
|
||||
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
} else {
|
||||
@@ -164,13 +185,21 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
bufferAccountCall := make(map[string]struct{})
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
} else {
|
||||
bufferAccountCall[p.accountID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for accountID := range bufferAccountCall {
|
||||
log.WithContext(ctx).Debugf("ephemeral - buffer update account peers for account: %s", accountID)
|
||||
e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
@@ -223,6 +252,6 @@ func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func newDeadLine() time.Time {
|
||||
return timeNow().Add(ephemeralLifeTime)
|
||||
func (e *EphemeralManager) newDeadLine() time.Time {
|
||||
return timeNow().Add(e.lifeTime)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -27,28 +30,65 @@ func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStren
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
type MocAccountManager struct {
|
||||
type MockAccountManager struct {
|
||||
mu sync.Mutex
|
||||
nbAccount.Manager
|
||||
store *MockStore
|
||||
store *MockStore
|
||||
deletePeerCalls int
|
||||
bufferUpdateCalls map[string]int
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.deletePeerCalls++
|
||||
if a.wg != nil {
|
||||
a.wg.Done()
|
||||
}
|
||||
delete(a.store.account.Peers, peerID)
|
||||
return nil //nolint:nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocAccountManager) GetStore() store.Store {
|
||||
func (a *MockAccountManager) GetDeletePeerCalls() int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.deletePeerCalls
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
a.bufferUpdateCalls = make(map[string]int)
|
||||
}
|
||||
a.bufferUpdateCalls[accountID]++
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
return 0
|
||||
}
|
||||
return a.bufferUpdateCalls[accountID]
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetStore() store.Store {
|
||||
return a.store
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
am := MocAccountManager{
|
||||
am := MockAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
@@ -56,7 +96,7 @@ func TestNewManager(t *testing.T) {
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr := NewEphemeralManager(store, &am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
@@ -67,13 +107,16 @@ func TestNewManager(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewManagerPeerConnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
am := MocAccountManager{
|
||||
am := MockAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
@@ -81,7 +124,7 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr := NewEphemeralManager(store, &am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
@@ -95,13 +138,16 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
am := MocAccountManager{
|
||||
am := MockAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
@@ -109,7 +155,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr := NewEphemeralManager(store, &am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
@@ -126,8 +172,38 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
||||
const (
|
||||
ephemeralPeers = 10
|
||||
testLifeTime = 1 * time.Second
|
||||
testCleanupWindow = 100 * time.Millisecond
|
||||
)
|
||||
mockStore := &MockStore{}
|
||||
mockAM := &MockAccountManager{
|
||||
store: mockStore,
|
||||
}
|
||||
mockAM.wg = &sync.WaitGroup{}
|
||||
mockAM.wg.Add(ephemeralPeers)
|
||||
mgr := NewEphemeralManager(mockStore, mockAM)
|
||||
mgr.lifeTime = testLifeTime
|
||||
mgr.cleanupWindow = testCleanupWindow
|
||||
|
||||
account := newAccountWithId(context.Background(), "account", "", "", false)
|
||||
mockStore.account = account
|
||||
for i := range ephemeralPeers {
|
||||
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
time.Sleep(testCleanupWindow / ephemeralPeers)
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
}
|
||||
mockAM.wg.Wait()
|
||||
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
|
||||
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
|
||||
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted only the first peer")
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
||||
@@ -369,7 +369,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,6 +15,7 @@ import (
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
@@ -47,6 +50,10 @@ type GRPCServer struct {
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
syncLimiter *rate.Limiter
|
||||
loginLimiterStore sync.Map
|
||||
loginPeerBooster int
|
||||
loginPeerLimit rate.Limit
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -76,6 +83,41 @@ func NewServer(
|
||||
}
|
||||
}
|
||||
|
||||
multiplier := time.Minute
|
||||
d, e := time.ParseDuration(os.Getenv("NB_LOGIN_RATE"))
|
||||
if e == nil {
|
||||
multiplier = d
|
||||
}
|
||||
|
||||
loginRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_RATE_PER_M"))
|
||||
if loginRatePerS == 0 || err != nil {
|
||||
loginRatePerS = 200
|
||||
}
|
||||
|
||||
loginBurst, err := strconv.Atoi(os.Getenv("NB_LOGIN_BURST"))
|
||||
if loginBurst == 0 || err != nil {
|
||||
loginBurst = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("login burst limit set to %d", loginBurst)
|
||||
|
||||
loginPeerRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_PEER_RATE_PER_M"))
|
||||
if loginPeerRatePerS == 0 || err != nil {
|
||||
loginPeerRatePerS = 200
|
||||
}
|
||||
log.WithContext(ctx).Infof("login rate limit set to %d/min", loginRatePerS)
|
||||
|
||||
syncRatePerS, err := strconv.Atoi(os.Getenv("NB_SYNC_RATE_PER_M"))
|
||||
if syncRatePerS == 0 || err != nil {
|
||||
syncRatePerS = 20000
|
||||
}
|
||||
log.WithContext(ctx).Infof("sync rate limit set to %d/min", syncRatePerS)
|
||||
|
||||
syncBurst, err := strconv.Atoi(os.Getenv("NB_SYNC_BURST"))
|
||||
if syncBurst == 0 || err != nil {
|
||||
syncBurst = 30000
|
||||
}
|
||||
log.WithContext(ctx).Infof("sync burst limit set to %d", syncBurst)
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -87,6 +129,9 @@ func NewServer(
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
syncLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(syncRatePerS)), syncBurst),
|
||||
loginPeerLimit: rate.Every(multiplier / time.Duration(loginPeerRatePerS)),
|
||||
loginPeerBooster: loginBurst,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -128,11 +173,17 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
reqStart := time.Now()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
|
||||
if !s.syncLimiter.Allow() {
|
||||
log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return status.Errorf(codes.Internal, "temp rate limit reached")
|
||||
}
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
@@ -428,15 +479,39 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
|
||||
limiterIface, ok := s.loginLimiterStore.Load(req.WgPubKey)
|
||||
if !ok {
|
||||
// Create new limiter for this peer
|
||||
newLimiter := rate.NewLimiter(s.loginPeerLimit, s.loginPeerBooster)
|
||||
s.loginLimiterStore.Store(req.WgPubKey, newLimiter)
|
||||
|
||||
if !newLimiter.Allow() {
|
||||
//time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (new peer limit)")
|
||||
}
|
||||
} else {
|
||||
// Use existing limiter for this peer
|
||||
limiter := limiterIface.(*rate.Limiter)
|
||||
if !limiter.Allow() {
|
||||
//time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100)))
|
||||
log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey)
|
||||
return nil, fmt.Errorf("temp rate limit reached (peer limit)")
|
||||
}
|
||||
}
|
||||
reqStart := time.Now()
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
//if s.appMetrics != nil {
|
||||
// s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
//}
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
|
||||
@@ -60,6 +60,8 @@ components:
|
||||
description: Account creator
|
||||
type: string
|
||||
example: google-oauth2|277474792786460067937
|
||||
onboarding:
|
||||
$ref: '#/components/schemas/AccountOnboarding'
|
||||
required:
|
||||
- id
|
||||
- settings
|
||||
@@ -67,6 +69,21 @@ components:
|
||||
- domain_category
|
||||
- created_at
|
||||
- created_by
|
||||
- onboarding
|
||||
AccountOnboarding:
|
||||
type: object
|
||||
properties:
|
||||
signup_form_pending:
|
||||
description: Indicates whether the account signup form is pending
|
||||
type: boolean
|
||||
example: true
|
||||
onboarding_flow_pending:
|
||||
description: Indicates whether the account onboarding flow is pending
|
||||
type: boolean
|
||||
example: false
|
||||
required:
|
||||
- signup_form_pending
|
||||
- onboarding_flow_pending
|
||||
AccountSettings:
|
||||
type: object
|
||||
properties:
|
||||
@@ -153,6 +170,8 @@ components:
|
||||
properties:
|
||||
settings:
|
||||
$ref: '#/components/schemas/AccountSettings'
|
||||
onboarding:
|
||||
$ref: '#/components/schemas/AccountOnboarding'
|
||||
required:
|
||||
- settings
|
||||
User:
|
||||
|
||||
@@ -250,8 +250,9 @@ type Account struct {
|
||||
DomainCategory string `json:"domain_category"`
|
||||
|
||||
// Id Account ID
|
||||
Id string `json:"id"`
|
||||
Settings AccountSettings `json:"settings"`
|
||||
Id string `json:"id"`
|
||||
Onboarding AccountOnboarding `json:"onboarding"`
|
||||
Settings AccountSettings `json:"settings"`
|
||||
}
|
||||
|
||||
// AccountExtraSettings defines model for AccountExtraSettings.
|
||||
@@ -266,9 +267,19 @@ type AccountExtraSettings struct {
|
||||
PeerApprovalEnabled bool `json:"peer_approval_enabled"`
|
||||
}
|
||||
|
||||
// AccountOnboarding defines model for AccountOnboarding.
|
||||
type AccountOnboarding struct {
|
||||
// OnboardingFlowPending Indicates whether the account onboarding flow is pending
|
||||
OnboardingFlowPending bool `json:"onboarding_flow_pending"`
|
||||
|
||||
// SignupFormPending Indicates whether the account signup form is pending
|
||||
SignupFormPending bool `json:"signup_form_pending"`
|
||||
}
|
||||
|
||||
// AccountRequest defines model for AccountRequest.
|
||||
type AccountRequest struct {
|
||||
Settings AccountSettings `json:"settings"`
|
||||
Onboarding *AccountOnboarding `json:"onboarding,omitempty"`
|
||||
Settings AccountSettings `json:"settings"`
|
||||
}
|
||||
|
||||
// AccountSettings defines model for AccountSettings.
|
||||
|
||||
@@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta)
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -126,6 +132,20 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
}
|
||||
|
||||
var onboarding *types.AccountOnboarding
|
||||
if req.Onboarding != nil {
|
||||
onboarding = &types.AccountOnboarding{
|
||||
OnboardingFlowPending: req.Onboarding.OnboardingFlowPending,
|
||||
SignupFormPending: req.Onboarding.SignupFormPending,
|
||||
}
|
||||
}
|
||||
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -138,7 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta)
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
@@ -167,7 +187,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account {
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
@@ -188,6 +208,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
}
|
||||
|
||||
apiOnboarding := api.AccountOnboarding{
|
||||
OnboardingFlowPending: onboarding.OnboardingFlowPending,
|
||||
SignupFormPending: onboarding.SignupFormPending,
|
||||
}
|
||||
|
||||
if settings.Extra != nil {
|
||||
apiSettings.Extra = &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||
@@ -203,5 +228,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
CreatedBy: meta.CreatedBy,
|
||||
Domain: meta.Domain,
|
||||
DomainCategory: meta.DomainCategory,
|
||||
Onboarding: apiOnboarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
||||
return account.GetMeta(), nil
|
||||
},
|
||||
GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
return &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}, nil
|
||||
},
|
||||
UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
return &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
settingsManager: settingsMockManager,
|
||||
}
|
||||
@@ -117,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
@@ -139,7 +151,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
@@ -161,7 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 554400,
|
||||
@@ -178,12 +190,34 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK without onboarding",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr("roles"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{"test"},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "Update account failure with high peer_login_expiration more than 180 days",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedArray: false,
|
||||
},
|
||||
@@ -192,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedArray: false,
|
||||
},
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
package testing_tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
@@ -138,7 +137,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -83,15 +83,12 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
var peers []*nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
return err
|
||||
})
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -440,11 +440,15 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
AnyTimes().
|
||||
Return(&types.Settings{}, nil)
|
||||
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
if err != nil {
|
||||
cleanup()
|
||||
|
||||
@@ -211,7 +211,7 @@ func startServer(
|
||||
port_forwarding.NewControllerMock(),
|
||||
settingsMockManager,
|
||||
permissionsManager,
|
||||
)
|
||||
false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -30,94 +30,95 @@ type MockAccountManager struct {
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func(settings *types.Settings) string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() account.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func(settings *types.Settings) string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() account.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
}
|
||||
|
||||
@@ -125,6 +126,10 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
if am.DeleteSetupKeyFunc != nil {
|
||||
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
|
||||
@@ -814,6 +819,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
if am.GetAccountOnboardingFunc != nil {
|
||||
return am.GetAccountOnboardingFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented")
|
||||
}
|
||||
|
||||
// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
if am.UpdateAccountOnboardingFunc != nil {
|
||||
return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented")
|
||||
}
|
||||
|
||||
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
if am.GetUserByIDFunc != nil {
|
||||
|
||||
@@ -778,8 +778,14 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -848,7 +854,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
|
||||
@@ -391,7 +391,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if updateAccountPeers && userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -1169,7 +1169,26 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
if am.metrics != nil {
|
||||
globalStart := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}()
|
||||
}
|
||||
|
||||
peersToUpdate := []*nbpeer.Peer{}
|
||||
for _, peer := range account.Peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
peersToUpdate = append(peersToUpdate, peer)
|
||||
}
|
||||
|
||||
if len(peersToUpdate) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
@@ -1192,62 +1211,70 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
return
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, peer := range peersToUpdate {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||
|
||||
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
wg.Wait()
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
log.WithContext(ctx).Debugf("try to BufferUpdateAccountPeers for account %s", accountID)
|
||||
|
||||
if !lock.TryLock() {
|
||||
log.WithContext(ctx).Debugf("BufferUpdateAccountPeers for an account %s locked - returning", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
// time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
defer lock.Unlock()
|
||||
log.WithContext(ctx).Debugf("BufferUpdateAccountPeers for an account %s - in progress", accountID)
|
||||
tn := time.Now()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
log.WithContext(ctx).Debugf("BufferUpdateAccountPeers for an account %s - took %dms", accountID, time.Since(tn).Milliseconds())
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -667,7 +667,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@@ -737,7 +737,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1267,7 +1267,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1340,9 +1340,14 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1477,7 +1482,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1544,9 +1549,14 @@ func Test_LoginPeer(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -2052,7 +2062,7 @@ func Test_DeletePeer(t *testing.T) {
|
||||
// account with an admin and a regular user
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Peers = map[string]*nbpeer.Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
|
||||
@@ -106,7 +106,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
|
||||
|
||||
@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
@@ -1305,7 +1305,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding
|
||||
func NewAccountOnboardingNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account onboarding not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
||||
func NewPeerNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "peer is not part of this account")
|
||||
|
||||
@@ -99,7 +99,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
||||
@@ -725,6 +725,32 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren
|
||||
return &accountMeta, nil
|
||||
}
|
||||
|
||||
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
||||
func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) {
|
||||
var accountOnboarding types.AccountOnboarding
|
||||
result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountOnboardingNotFoundError(accountID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return &accountOnboarding, nil
|
||||
}
|
||||
|
||||
// SaveAccountOnboarding updates the onboarding information for a specific account.
|
||||
func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error {
|
||||
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
||||
return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -1184,7 +1210,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
|
||||
for _, account := range fileStore.GetAllAccounts(ctx) {
|
||||
_, err = account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,9 +353,16 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
o, err := store.GetAccountOnboarding(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, o.AccountID, account.Id)
|
||||
|
||||
err = store.DeleteAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetAccountOnboarding(context.Background(), account.Id)
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding")
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
|
||||
}
|
||||
@@ -413,12 +420,21 @@ func Test_GetAccount(t *testing.T) {
|
||||
account, err := store.GetAccount(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, account.Id, "account id should match")
|
||||
require.Equal(t, false, account.Onboarding.OnboardingFlowPending)
|
||||
|
||||
id = "9439-34653001fc3b-bf1c8084-ba50-4ce7"
|
||||
|
||||
account, err = store.GetAccount(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, account.Id, "account id should match")
|
||||
require.Equal(t, true, account.Onboarding.OnboardingFlowPending)
|
||||
|
||||
_, err = store.GetAccount(context.Background(), "non-existing-account")
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2042,9 +2058,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
if err := acc.AddAllGroup(false); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -3386,6 +3403,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) {
|
||||
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountOnboarding(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
|
||||
a, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Onboarding: %+v", a.Onboarding)
|
||||
err = store.SaveAccount(context.Background(), a)
|
||||
require.NoError(t, err)
|
||||
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
require.Equal(t, accountID, onboarding.AccountID)
|
||||
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC())
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveAccountOnboarding(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
t.Run("New onboarding should be saved correctly", func(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
onboarding := &types.AccountOnboarding{
|
||||
AccountID: accountID,
|
||||
SignupFormPending: true,
|
||||
OnboardingFlowPending: true,
|
||||
}
|
||||
|
||||
err = store.SaveAccountOnboarding(context.Background(), onboarding)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
|
||||
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
|
||||
})
|
||||
|
||||
t.Run("Existing onboarding should be updated correctly", func(t *testing.T) {
|
||||
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
|
||||
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending
|
||||
onboarding.SignupFormPending = !onboarding.SignupFormPending
|
||||
|
||||
err = store.SaveAccountOnboarding(context.Background(), onboarding)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
|
||||
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAnyAccountID(t *testing.T) {
|
||||
t.Run("should return account ID when accounts exist", func(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
|
||||
@@ -52,6 +52,7 @@ type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
|
||||
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
||||
@@ -74,6 +75,7 @@ type Store interface {
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
|
||||
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
||||
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
|
||||
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||
@@ -391,7 +393,7 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
|
||||
|
||||
_, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
return err
|
||||
}
|
||||
shouldSave = true
|
||||
|
||||
@@ -18,6 +18,10 @@ type UpdateChannelMetrics struct {
|
||||
getAllConnectedPeersDurationMicro metric.Int64Histogram
|
||||
getAllConnectedPeers metric.Int64Histogram
|
||||
hasChannelDurationMicro metric.Int64Histogram
|
||||
calcPostureChecksDurationMicro metric.Int64Histogram
|
||||
calcPeerNetworkMapDurationMs metric.Int64Histogram
|
||||
mergeNetworkMapDurationMicro metric.Int64Histogram
|
||||
toSyncResponseDurationMicro metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -89,6 +93,38 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
calcPostureChecksDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.posturechecks.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to get the posture checks for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
calcPeerNetworkMapDurationMs, err := meter.Int64Histogram("management.updatechannel.calc.networkmap.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to calculate the network map for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergeNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.merge.networkmap.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to merge the network maps for a peer"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toSyncResponseDurationMicro, err := meter.Int64Histogram("management.updatechannel.tosyncresponse.duration.micro",
|
||||
metric.WithUnit("microseconds"),
|
||||
metric.WithDescription("Duration of how long it takes to convert the network map to sync response"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UpdateChannelMetrics{
|
||||
createChannelDurationMicro: createChannelDurationMicro,
|
||||
closeChannelDurationMicro: closeChannelDurationMicro,
|
||||
@@ -98,6 +134,10 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro,
|
||||
getAllConnectedPeers: getAllConnectedPeers,
|
||||
hasChannelDurationMicro: hasChannelDurationMicro,
|
||||
calcPostureChecksDurationMicro: calcPostureChecksDurationMicro,
|
||||
calcPeerNetworkMapDurationMs: calcPeerNetworkMapDurationMs,
|
||||
mergeNetworkMapDurationMicro: mergeNetworkMapDurationMicro,
|
||||
toSyncResponseDurationMicro: toSyncResponseDurationMicro,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -137,3 +177,19 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration
|
||||
func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) {
|
||||
metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountCalcPostureChecksDuration(duration time.Duration) {
|
||||
metrics.calcPostureChecksDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountCalcPeerNetworkMapDuration(duration time.Duration) {
|
||||
metrics.calcPeerNetworkMapDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountMergeNetworkMapDuration(duration time.Duration) {
|
||||
metrics.mergeNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
func (metrics *UpdateChannelMetrics) CountToSyncResponseDuration(duration time.Duration) {
|
||||
metrics.toSyncResponseDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
}
|
||||
|
||||
4
management/server/testdata/store.sql
vendored
4
management/server/testdata/store.sql
vendored
@@ -1,4 +1,5 @@
|
||||
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||
CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`));
|
||||
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||
@@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`);
|
||||
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
|
||||
@@ -82,11 +82,11 @@ type Account struct {
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||
// Settings is a dictionary of Account settings
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
|
||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
// Subclass used in gorm to only load network and not whole account
|
||||
@@ -104,6 +104,20 @@ type AccountSettings struct {
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
}
|
||||
|
||||
type AccountOnboarding struct {
|
||||
AccountID string `gorm:"primaryKey"`
|
||||
OnboardingFlowPending bool
|
||||
SignupFormPending bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// IsEqual compares two AccountOnboarding objects and returns true if they are equal
|
||||
func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
return o.OnboardingFlowPending == onboarding.OnboardingFlowPending &&
|
||||
o.SignupFormPending == onboarding.SignupFormPending
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
@@ -866,6 +880,7 @@ func (a *Account) Copy() *Account {
|
||||
Networks: nets,
|
||||
NetworkRouters: networkRouters,
|
||||
NetworkResources: networkResources,
|
||||
Onboarding: a.Onboarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1546,7 +1561,7 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
|
||||
}
|
||||
|
||||
// AddAllGroup to account object if it doesn't exist
|
||||
func (a *Account) AddAllGroup() error {
|
||||
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
if len(a.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
@@ -1558,6 +1573,10 @@ func (a *Account) AddAllGroup() error {
|
||||
}
|
||||
a.Groups = map[string]*Group{allGroup.ID: allGroup}
|
||||
|
||||
if disableDefaultPolicy {
|
||||
return nil
|
||||
}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &Policy{
|
||||
|
||||
@@ -53,6 +53,9 @@ type Config struct {
|
||||
StoreConfig StoreConfig
|
||||
|
||||
ReverseProxy ReverseProxy
|
||||
|
||||
// disable default all-to-all policy
|
||||
DisableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -103,7 +103,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
@@ -131,7 +131,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
@@ -163,7 +163,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -188,7 +188,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -213,7 +213,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
@@ -256,7 +256,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -296,7 +296,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -406,7 +406,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -453,7 +453,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -501,7 +501,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -532,7 +532,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -639,7 +639,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -678,7 +678,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -705,7 +705,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -792,7 +792,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -952,7 +952,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -988,7 +988,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
|
||||
|
||||
@@ -1030,7 +1030,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
externalUser := &types.User{
|
||||
Id: "externalUser",
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1098,7 +1098,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1132,7 +1132,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1499,7 +1499,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
|
||||
targetId := "user2"
|
||||
account1.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
@@ -1508,7 +1508,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account2))
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
@@ -1535,7 +1535,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "")
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
|
||||
account1.Settings.RegularUsersViewBlocked = false
|
||||
account1.Users["blocked-user"] = &types.User{
|
||||
Id: "blocked-user",
|
||||
@@ -1557,7 +1557,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, store.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "")
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
|
||||
account2.Users["settings-blocked-user"] = &types.User{
|
||||
Id: "settings-blocked-user",
|
||||
Role: types.UserRoleUser,
|
||||
|
||||
@@ -130,7 +130,7 @@ repo_gpgcheck=1
|
||||
EOF
|
||||
}
|
||||
|
||||
add_aur_repo() {
|
||||
install_aur_package() {
|
||||
INSTALL_PKGS="git base-devel go"
|
||||
REMOVE_PKGS=""
|
||||
|
||||
@@ -154,8 +154,10 @@ add_aur_repo() {
|
||||
cd netbird-ui && makepkg -sri --noconfirm
|
||||
fi
|
||||
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
if [ -n "$REMOVE_PKGS" ]; then
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
fi
|
||||
}
|
||||
|
||||
prepare_tun_module() {
|
||||
@@ -277,7 +279,9 @@ install_netbird() {
|
||||
;;
|
||||
pacman)
|
||||
${SUDO} pacman -Syy
|
||||
add_aur_repo
|
||||
install_aur_package
|
||||
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
|
||||
${SUDO} systemctl enable --now netbird@main.service
|
||||
;;
|
||||
pkg)
|
||||
# Check if the package is already installed
|
||||
@@ -494,4 +498,4 @@ case "$UPDATE_FLAG" in
|
||||
;;
|
||||
*)
|
||||
install_netbird
|
||||
esac
|
||||
esac
|
||||
|
||||
Reference in New Issue
Block a user