mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-05 08:36:37 +00:00
Merge branch 'main' into handle-existing-domain-user
# Conflicts: # management/server/account.go # management/server/account_test.go
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -39,6 +39,7 @@ type Manager interface {
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
@@ -50,6 +51,7 @@ type Manager interface {
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
@@ -61,8 +63,10 @@ type Manager interface {
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
@@ -88,7 +92,8 @@ type Manager interface {
|
||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
@@ -99,7 +104,7 @@ type Manager interface {
|
||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManager() idp.Manager
|
||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
@@ -110,10 +115,11 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
GetStore() store.Store
|
||||
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -373,7 +374,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -398,7 +399,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -640,7 +641,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
@@ -782,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||
|
||||
@@ -793,7 +794,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -899,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||
}
|
||||
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
}
|
||||
@@ -1159,7 +1160,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1194,7 +1195,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}()
|
||||
|
||||
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1208,6 +1209,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
// Ensure that we do not receive an update message before the policy is deleted
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case <-updMsg:
|
||||
t.Logf("received addPeer update message before policy deletion")
|
||||
default:
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1232,11 +1241,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
AccountID: account.Id,
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1284,7 +1294,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1335,11 +1345,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
@@ -1664,9 +1674,10 @@ func TestAccount_Copy(t *testing.T) {
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -1775,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.NotNil(t, settings)
|
||||
@@ -1805,9 +1816,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1825,11 +1837,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1856,15 +1868,13 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
wg.Add(1)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
@@ -1919,9 +1929,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1935,6 +1946,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
failed = waitTimeout(wg, time.Second)
|
||||
@@ -1950,15 +1962,16 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
@@ -1967,12 +1980,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||
}
|
||||
@@ -2604,6 +2619,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -2611,11 +2627,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "accountID",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
|
||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
|
||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
|
||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
|
||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
|
||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
||||
@@ -2639,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
||||
})
|
||||
@@ -2653,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
|
||||
})
|
||||
@@ -2667,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
|
||||
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
|
||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
|
||||
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
@@ -2688,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2706,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2720,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2734,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
})
|
||||
@@ -2752,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
||||
@@ -2767,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
||||
})
|
||||
@@ -2875,7 +2891,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3135,11 +3151,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 7, 20, 10, 80},
|
||||
{"Small", 50, 5, 7, 20, 5, 80},
|
||||
{"Medium", 500, 100, 5, 40, 30, 140},
|
||||
{"Large", 5000, 200, 80, 120, 140, 390},
|
||||
{"Small single", 50, 10, 7, 20, 10, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 20, 85},
|
||||
{"Small single", 50, 10, 7, 20, 6, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 15, 85},
|
||||
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
||||
}
|
||||
|
||||
@@ -3198,7 +3214,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -3209,9 +3225,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, created)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
|
||||
@@ -3220,9 +3237,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
assert.Equal(t, 0, len(account.Users))
|
||||
assert.Equal(t, 0, len(account.SetupKeys))
|
||||
|
||||
// retry should fail
|
||||
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.Error(t, err)
|
||||
// should return a new account because the previous one is not primary
|
||||
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, created2)
|
||||
assert.False(t, account2.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account2.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
|
||||
assert.Equal(t, initiatorId, account2.CreatedBy)
|
||||
assert.Equal(t, 1, len(account2.Groups))
|
||||
assert.Equal(t, 0, len(account2.Users))
|
||||
assert.Equal(t, 0, len(account2.SetupKeys))
|
||||
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||
assert.Error(t, err, "should not be able to update a second account to primary")
|
||||
}
|
||||
|
||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
@@ -3236,14 +3269,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, created)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
|
||||
// retry should fail
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, created2)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, account.Id, account2.Id)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
@@ -3296,6 +3336,123 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
})
|
||||
|
||||
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
|
||||
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.CreateGroup(ctx, group1))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
})
|
||||
|
||||
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
|
||||
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.CreateGroup(ctx, group2))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group2.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
|
||||
Name: "Group1 Policy",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"group1"},
|
||||
Destinations: []string{"group2"},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.True(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
})
|
||||
|
||||
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = []string{"group1"}
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -3339,3 +3496,141 @@ func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
assert.Equal(t, account.Id, onboarding.AccountID)
|
||||
assert.Equal(t, true, onboarding.OnboardingFlowPending)
|
||||
assert.Equal(t, true, onboarding.SignupFormPending)
|
||||
if onboarding.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was not retrieved from the store")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
|
||||
account.Id = "with-zero-onboarding"
|
||||
account.Onboarding = types.AccountOnboarding{}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
|
||||
require.Error(t, err, "should return error when onboarding is not set")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
onboarding := &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}
|
||||
|
||||
t.Run("update onboarding with no change", func(t *testing.T) {
|
||||
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
if updated.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was updated in the store")
|
||||
}
|
||||
})
|
||||
|
||||
onboarding.OnboardingFlowPending = false
|
||||
onboarding.SignupFormPending = false
|
||||
|
||||
t.Run("update onboarding", func(t *testing.T) {
|
||||
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
})
|
||||
|
||||
t.Run("update onboarding with no onboarding", func(t *testing.T) {
|
||||
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer1")
|
||||
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer2")
|
||||
|
||||
t.Run("update peer IP successfully", func(t *testing.T) {
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get account")
|
||||
|
||||
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
|
||||
require.NoError(t, err, "unable to allocate new IP")
|
||||
|
||||
newAddr := netip.MustParseAddr(newIP.String())
|
||||
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
|
||||
require.NoError(t, err, "unable to update peer IP")
|
||||
|
||||
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
|
||||
require.NoError(t, err, "unable to get updated peer")
|
||||
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
|
||||
})
|
||||
|
||||
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
|
||||
currentAddr := netip.MustParseAddr(peer1.IP.String())
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
|
||||
require.NoError(t, err, "updating with same IP should not error")
|
||||
})
|
||||
|
||||
t.Run("update peer IP with collision should fail", func(t *testing.T) {
|
||||
peer2Addr := netip.MustParseAddr(peer2.IP.String())
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
|
||||
require.Error(t, err, "should fail when IP is already assigned")
|
||||
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
|
||||
})
|
||||
|
||||
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
|
||||
invalidAddr := netip.MustParseAddr("192.168.1.100")
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
|
||||
require.Error(t, err, "should fail when IP is outside network range")
|
||||
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
|
||||
})
|
||||
|
||||
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
|
||||
newAddr := netip.MustParseAddr("100.64.0.101")
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
|
||||
require.Error(t, err, "should fail with invalid peer ID")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -171,6 +171,14 @@ const (
|
||||
ResourceRemovedFromGroup Activity = 83
|
||||
|
||||
AccountDNSDomainUpdated Activity = 84
|
||||
|
||||
AccountLazyConnectionEnabled Activity = 85
|
||||
AccountLazyConnectionDisabled Activity = 86
|
||||
|
||||
AccountNetworkRangeUpdated Activity = 87
|
||||
PeerIPUpdated Activity = 88
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@@ -179,6 +187,7 @@ var activityMap = map[Activity]Code{
|
||||
UserJoined: {"User joined", "user.join"},
|
||||
UserInvited: {"User invited", "user.invite"},
|
||||
AccountCreated: {"Account created", "account.create"},
|
||||
AccountDeleted: {"Account deleted", "account.delete"},
|
||||
PeerRemovedByUser: {"Peer deleted", "user.peer.delete"},
|
||||
RuleAdded: {"Rule added", "rule.add"},
|
||||
RuleUpdated: {"Rule updated", "rule.update"},
|
||||
@@ -268,6 +277,13 @@ var activityMap = map[Activity]Code{
|
||||
ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"},
|
||||
|
||||
AccountDNSDomainUpdated: {"Account DNS domain updated", "account.dns.domain.update"},
|
||||
|
||||
AccountLazyConnectionEnabled: {"Account lazy connection enabled", "account.setting.lazy.connection.enable"},
|
||||
AccountLazyConnectionDisabled: {"Account lazy connection disabled", "account.setting.lazy.connection.disable"},
|
||||
|
||||
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
|
||||
|
||||
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -19,22 +19,22 @@ type Event struct {
|
||||
// Timestamp of the event
|
||||
Timestamp time.Time
|
||||
// Activity that was performed during the event
|
||||
Activity ActivityDescriber
|
||||
Activity Activity `gorm:"type:integer"`
|
||||
// ID of the event (can be empty, meaning that it wasn't yet generated)
|
||||
ID uint64
|
||||
ID uint64 `gorm:"primaryKey;autoIncrement"`
|
||||
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
|
||||
InitiatorID string
|
||||
// InitiatorName is the name of an object that initiated the event.
|
||||
InitiatorName string
|
||||
InitiatorName string `gorm:"-"`
|
||||
// InitiatorEmail is the email address of an object that initiated the event.
|
||||
InitiatorEmail string
|
||||
InitiatorEmail string `gorm:"-"`
|
||||
// TargetID is the ID of an object that was effected by the event (e.g., a peer)
|
||||
TargetID string
|
||||
// AccountID is the ID of an account where the event happened
|
||||
AccountID string
|
||||
AccountID string `gorm:"index"`
|
||||
|
||||
// Meta of the event, e.g. deleted peer information like name, IP, etc
|
||||
Meta map[string]any
|
||||
Meta map[string]any `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy the event
|
||||
@@ -57,3 +57,10 @@ func (e *Event) Copy() *Event {
|
||||
Meta: meta,
|
||||
}
|
||||
}
|
||||
|
||||
type DeletedUser struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Email string `gorm:"not null"`
|
||||
Name string
|
||||
EncAlgo string `gorm:"not null"`
|
||||
}
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
|
||||
if _, err := db.Exec(createTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := updateDeletedUsersTable(ctx, db); err != nil {
|
||||
return fmt.Errorf("failed to update deleted_users table: %v", err)
|
||||
}
|
||||
|
||||
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
|
||||
}
|
||||
|
||||
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
|
||||
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
||||
exists, err := checkColumnExists(db, "deleted_users", "name")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
|
||||
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
|
||||
}
|
||||
|
||||
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
|
||||
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
|
||||
// legacy CBC encryption with a static IV to the new GCM encryption method.
|
||||
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
|
||||
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute select query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare update statement: %v", err)
|
||||
}
|
||||
defer updateStmt.Close()
|
||||
|
||||
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %v", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
|
||||
return nil
|
||||
}
|
||||
|
||||
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
|
||||
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, decryptedEmail, decryptedName string
|
||||
email, name *string
|
||||
)
|
||||
|
||||
err := rows.Scan(&id, &email, &name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if email != nil {
|
||||
decryptedEmail, err = crypt.LegacyDecrypt(*email)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
|
||||
id,
|
||||
fmt.Errorf("failed to decrypt email: %w", err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if name != nil {
|
||||
decryptedName, err = crypt.LegacyDecrypt(*name)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
|
||||
id,
|
||||
fmt.Errorf("failed to decrypt name: %w", err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt email: %w", err)
|
||||
}
|
||||
|
||||
encryptedName, err := crypt.Encrypt(decryptedName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt name: %w", err)
|
||||
}
|
||||
|
||||
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupDatabase(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
require.NoError(t, err, "Failed to open database")
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
_, err = db.Exec(createTableQuery)
|
||||
require.NoError(t, err, "Failed to create events table")
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
|
||||
require.NoError(t, err, "Failed to create deleted_users table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
key, err := GenerateKey()
|
||||
require.NoError(t, err, "Failed to generate key")
|
||||
|
||||
crypt, err := NewFieldEncrypt(key)
|
||||
require.NoError(t, err, "Failed to initialize FieldEncrypt")
|
||||
|
||||
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
|
||||
legacyName := crypt.LegacyEncrypt("Test Account")
|
||||
|
||||
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
|
||||
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
|
||||
require.NoError(t, err, "Failed to insert event")
|
||||
|
||||
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
|
||||
require.NoError(t, err, "Failed to insert legacy encrypted data")
|
||||
|
||||
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
|
||||
require.NoError(t, err, "Failed to check if enc_algo column exists")
|
||||
require.False(t, colExists, "enc_algo column should not exist before migration")
|
||||
|
||||
err = migrate(context.Background(), crypt, db)
|
||||
require.NoError(t, err, "Migration failed")
|
||||
|
||||
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
|
||||
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
|
||||
require.True(t, colExists, "enc_algo column should exist after migration")
|
||||
|
||||
var encAlgo string
|
||||
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
|
||||
require.NoError(t, err, "Failed to select updated data")
|
||||
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
|
||||
|
||||
store, err := createStore(crypt, db)
|
||||
require.NoError(t, err, "Failed to create store")
|
||||
|
||||
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
|
||||
require.NoError(t, err, "Failed to get events")
|
||||
|
||||
require.Len(t, events, 1, "Should have one event")
|
||||
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
|
||||
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
|
||||
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
|
||||
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
|
||||
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
|
||||
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
|
||||
}
|
||||
@@ -1,359 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
const (
|
||||
// eventSinkDB is the default name of the events database
|
||||
eventSinkDB = "events.db"
|
||||
createTableQuery = "CREATE TABLE IF NOT EXISTS events " +
|
||||
"(id INTEGER PRIMARY KEY AUTOINCREMENT, " +
|
||||
"activity INTEGER, " +
|
||||
"timestamp DATETIME, " +
|
||||
"initiator_id TEXT," +
|
||||
"account_id TEXT," +
|
||||
"meta TEXT," +
|
||||
" target_id TEXT);"
|
||||
|
||||
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
|
||||
|
||||
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
||||
FROM events
|
||||
LEFT JOIN (
|
||||
SELECT id, MAX(name) as name, MAX(email) as email
|
||||
FROM deleted_users
|
||||
GROUP BY id
|
||||
) i ON events.initiator_id = i.id
|
||||
LEFT JOIN (
|
||||
SELECT id, MAX(name) as name, MAX(email) as email
|
||||
FROM deleted_users
|
||||
GROUP BY id
|
||||
) t ON events.target_id = t.id
|
||||
WHERE account_id = ?
|
||||
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
|
||||
|
||||
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
||||
FROM events
|
||||
LEFT JOIN (
|
||||
SELECT id, MAX(name) as name, MAX(email) as email
|
||||
FROM deleted_users
|
||||
GROUP BY id
|
||||
) i ON events.initiator_id = i.id
|
||||
LEFT JOIN (
|
||||
SELECT id, MAX(name) as name, MAX(email) as email
|
||||
FROM deleted_users
|
||||
GROUP BY id
|
||||
) t ON events.target_id = t.id
|
||||
WHERE account_id = ?
|
||||
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
|
||||
|
||||
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
|
||||
"VALUES(?, ?, ?, ?, ?, ?)"
|
||||
|
||||
/*
|
||||
TODO:
|
||||
The insert should avoid duplicated IDs in the table. So the query should be changes to something like:
|
||||
`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?) ON CONFLICT (id) DO UPDATE SET email = EXCLUDED.email, name = EXCLUDED.name;`
|
||||
For this to work we have to set the id column as primary key. But this is not possible because the id column is not unique
|
||||
and some selfhosted deployments might have duplicates already so we need to clean the table first.
|
||||
*/
|
||||
|
||||
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
|
||||
|
||||
fallbackName = "unknown"
|
||||
fallbackEmail = "unknown@unknown.com"
|
||||
|
||||
gcmEncAlgo = "GCM"
|
||||
)
|
||||
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
fieldEncrypt *FieldEncrypt
|
||||
|
||||
insertStatement *sql.Stmt
|
||||
selectAscStatement *sql.Stmt
|
||||
selectDescStatement *sql.Stmt
|
||||
deleteUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a new Store with an event table if not exists.
|
||||
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
|
||||
dbFile := filepath.Join(dataDir, eventSinkDB)
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(runtime.NumCPU())
|
||||
|
||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = migrate(ctx, crypt, db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("events database migration: %w", err)
|
||||
}
|
||||
|
||||
return createStore(crypt, db)
|
||||
}
|
||||
|
||||
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
|
||||
events := make([]*activity.Event, 0)
|
||||
var cryptErr error
|
||||
for result.Next() {
|
||||
var id int64
|
||||
var operation activity.Activity
|
||||
var timestamp time.Time
|
||||
var initiator string
|
||||
var initiatorName *string
|
||||
var initiatorEmail *string
|
||||
var target string
|
||||
var targetUserName *string
|
||||
var targetEmail *string
|
||||
var account string
|
||||
var jsonMeta string
|
||||
err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorName, &initiatorEmail, &target, &targetUserName, &targetEmail, &account, &jsonMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := make(map[string]any)
|
||||
if jsonMeta != "" {
|
||||
err = json.Unmarshal([]byte(jsonMeta), &meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if targetUserName != nil {
|
||||
name, err := store.fieldEncrypt.Decrypt(*targetUserName)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", target)
|
||||
meta["username"] = fallbackName
|
||||
} else {
|
||||
meta["username"] = name
|
||||
}
|
||||
}
|
||||
|
||||
if targetEmail != nil {
|
||||
email, err := store.fieldEncrypt.Decrypt(*targetEmail)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", target)
|
||||
meta["email"] = fallbackEmail
|
||||
} else {
|
||||
meta["email"] = email
|
||||
}
|
||||
}
|
||||
|
||||
event := &activity.Event{
|
||||
Timestamp: timestamp,
|
||||
Activity: operation,
|
||||
ID: uint64(id),
|
||||
InitiatorID: initiator,
|
||||
TargetID: target,
|
||||
AccountID: account,
|
||||
Meta: meta,
|
||||
}
|
||||
|
||||
if initiatorName != nil {
|
||||
name, err := store.fieldEncrypt.Decrypt(*initiatorName)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", initiator)
|
||||
event.InitiatorName = fallbackName
|
||||
} else {
|
||||
event.InitiatorName = name
|
||||
}
|
||||
}
|
||||
|
||||
if initiatorEmail != nil {
|
||||
email, err := store.fieldEncrypt.Decrypt(*initiatorEmail)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", initiator)
|
||||
event.InitiatorEmail = fallbackEmail
|
||||
} else {
|
||||
event.InitiatorEmail = email
|
||||
}
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
if cryptErr != nil {
|
||||
log.WithContext(ctx).Warnf("%s", cryptErr)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
|
||||
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
stmt := store.selectDescStatement
|
||||
if !descending {
|
||||
stmt = store.selectAscStatement
|
||||
}
|
||||
|
||||
result, err := stmt.Query(accountID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer result.Close() //nolint
|
||||
return store.processResult(ctx, result)
|
||||
}
|
||||
|
||||
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
|
||||
var jsonMeta string
|
||||
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if meta != nil {
|
||||
metaBytes, err := json.Marshal(event.Meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonMeta = string(metaBytes)
|
||||
}
|
||||
|
||||
result, err := store.insertStatement.Exec(event.Activity, event.Timestamp, event.InitiatorID, event.TargetID, event.AccountID, jsonMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eventCopy := event.Copy()
|
||||
eventCopy.ID = uint64(id)
|
||||
return eventCopy, nil
|
||||
}
|
||||
|
||||
// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
|
||||
// this item from meta map
|
||||
func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
|
||||
email, ok := event.Meta["email"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
name, ok := event.Meta["name"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(event.Meta) == 2 {
|
||||
return nil, nil // nolint
|
||||
}
|
||||
delete(event.Meta, "email")
|
||||
delete(event.Meta, "name")
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
// Close the Store
|
||||
func (store *Store) Close(_ context.Context) error {
|
||||
if store.db != nil {
|
||||
return store.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createStore initializes and returns a new Store instance with prepared SQL statements.
|
||||
func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
|
||||
insertStmt, err := db.Prepare(insertQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selectDescStmt, err := db.Prepare(selectDescQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selectAscStmt, err := db.Prepare(selectAscQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Store{
|
||||
db: db,
|
||||
fieldEncrypt: crypt,
|
||||
insertStatement: insertStmt,
|
||||
selectDescStatement: selectDescStmt,
|
||||
selectAscStatement: selectAscStmt,
|
||||
deleteUserStmt: deleteUserStmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkColumnExists checks if a column exists in a specified table
|
||||
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
|
||||
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to query table info: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, ctype string
|
||||
var notnull, pk int
|
||||
var dfltValue sql.NullString
|
||||
|
||||
err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
if name == columnName {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package sqlite
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -1,4 +1,4 @@
|
||||
package sqlite
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
185
management/server/activity/store/migration.go
Normal file
185
management/server/activity/store/migration.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
)
|
||||
|
||||
func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
|
||||
migrations := getMigrations(ctx, crypt)
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "enc_algo", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migrateLegacyEncryptedUsersToGCM(ctx, db, crypt)
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migrateDuplicateDeletedUsers(ctx, db)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using
|
||||
// legacy CBC encryption with a static IV to the new GCM encryption method.
|
||||
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error {
|
||||
model := &activity.DeletedUser{}
|
||||
|
||||
if !db.Migrator().HasTable(model) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no CBC to GCM migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
var deletedUsers []activity.DeletedUser
|
||||
err := db.Model(model).Find(&deletedUsers, "enc_algo IS NULL OR enc_algo != ?", gcmEncAlgo).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query deleted_users: %w", err)
|
||||
}
|
||||
|
||||
if len(deletedUsers) == 0 {
|
||||
log.WithContext(ctx).Debug("No CBC encrypted deleted users to migrate")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, user := range deletedUsers {
|
||||
if err = updateDeletedUserData(tx, user, crypt); err != nil {
|
||||
return fmt.Errorf("failed to migrate deleted user %s: %w", user.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error {
|
||||
var err error
|
||||
var decryptedEmail, decryptedName string
|
||||
|
||||
if user.Email != "" {
|
||||
decryptedEmail, err = crypt.LegacyDecrypt(user.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if user.Name != "" {
|
||||
decryptedName, err = crypt.LegacyDecrypt(user.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
updatedUser := user
|
||||
updatedUser.EncAlgo = gcmEncAlgo
|
||||
|
||||
updatedUser.Email, err = crypt.Encrypt(decryptedEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt email: %w", err)
|
||||
}
|
||||
|
||||
updatedUser.Name, err = crypt.Encrypt(decryptedName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt name: %w", err)
|
||||
}
|
||||
|
||||
return transaction.Model(&updatedUser).Omit("id").Updates(updatedUser).Error
|
||||
}
|
||||
|
||||
// MigrateDuplicateDeletedUsers removes duplicates and ensures the id column is marked as the primary key
|
||||
func migrateDuplicateDeletedUsers(ctx context.Context, db *gorm.DB) error {
|
||||
model := &activity.DeletedUser{}
|
||||
if !db.Migrator().HasTable(model) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no duplicate migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
isPrimaryKey, err := isColumnPrimaryKey[activity.DeletedUser](db, "id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isPrimaryKey {
|
||||
log.WithContext(ctx).Debug("No duplicate deleted users to migrate")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = db.Transaction(func(tx *gorm.DB) error {
|
||||
if err = tx.Migrator().RenameTable("deleted_users", "deleted_users_old"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Migrator().CreateTable(model); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var deletedUsers []activity.DeletedUser
|
||||
if err = tx.Table("deleted_users_old").Find(&deletedUsers).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, deletedUser := range deletedUsers {
|
||||
if err = tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"email", "name", "enc_algo"}),
|
||||
}).Create(&deletedUser).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Migrator().DropTable("deleted_users_old")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully migrated duplicate deleted users")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isColumnPrimaryKey checks if a column is a primary key in the given model
|
||||
func isColumnPrimaryKey[T any](db *gorm.DB, columnName string) (bool, error) {
|
||||
var model T
|
||||
|
||||
cols, err := db.Migrator().ColumnTypes(&model)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, col := range cols {
|
||||
if col.Name() == columnName {
|
||||
isPrimaryKey, _ := col.PrimaryKey()
|
||||
return isPrimaryKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
143
management/server/activity/store/migration_test.go
Normal file
143
management/server/activity/store/migration_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
insertDeletedUserQuery = `INSERT INTO deleted_users (id, email, name, enc_algo) VALUES (?, ?, ?, ?)`
|
||||
)
|
||||
|
||||
func setupDatabase(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
|
||||
require.NoError(t, err, "Failed to create Postgres test container")
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn))
|
||||
require.NoError(t, err)
|
||||
|
||||
sql, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sql.Close()
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
key, err := GenerateKey()
|
||||
require.NoError(t, err, "Failed to generate key")
|
||||
|
||||
crypt, err := NewFieldEncrypt(key)
|
||||
require.NoError(t, err, "Failed to initialize FieldEncrypt")
|
||||
|
||||
t.Run("empty table, no migration required", func(t *testing.T) {
|
||||
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
|
||||
assert.False(t, db.Migrator().HasTable("deleted_users"))
|
||||
})
|
||||
|
||||
require.NoError(t, db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`).Error)
|
||||
assert.True(t, db.Migrator().HasTable("deleted_users"))
|
||||
assert.False(t, db.Migrator().HasColumn("deleted_users", "enc_algo"))
|
||||
|
||||
require.NoError(t, migration.MigrateNewField[activity.DeletedUser](context.Background(), db, "enc_algo", ""))
|
||||
assert.True(t, db.Migrator().HasColumn("deleted_users", "enc_algo"))
|
||||
|
||||
t.Run("legacy users migration", func(t *testing.T) {
|
||||
legacyEmail := crypt.LegacyEncrypt("test.user@test.com")
|
||||
legacyName := crypt.LegacyEncrypt("Test User")
|
||||
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", legacyEmail, legacyName, "").Error)
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user2", legacyEmail, legacyName, "legacy").Error)
|
||||
|
||||
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
|
||||
|
||||
var users []activity.DeletedUser
|
||||
require.NoError(t, db.Find(&users).Error)
|
||||
assert.Len(t, users, 2)
|
||||
|
||||
for _, user := range users {
|
||||
assert.Equal(t, gcmEncAlgo, user.EncAlgo)
|
||||
|
||||
decryptedEmail, err := crypt.Decrypt(user.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test.user@test.com", decryptedEmail)
|
||||
|
||||
decryptedName, err := crypt.Decrypt(user.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Test User", decryptedName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("users already migrated, no migration", func(t *testing.T) {
|
||||
encryptedEmail, err := crypt.Encrypt("test.user@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
encryptedName, err := crypt.Encrypt("Test User")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user3", encryptedEmail, encryptedName, gcmEncAlgo).Error)
|
||||
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
|
||||
|
||||
var users []activity.DeletedUser
|
||||
require.NoError(t, db.Find(&users).Error)
|
||||
assert.Len(t, users, 3)
|
||||
|
||||
for _, user := range users {
|
||||
assert.Equal(t, gcmEncAlgo, user.EncAlgo)
|
||||
|
||||
decryptedEmail, err := crypt.Decrypt(user.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test.user@test.com", decryptedEmail)
|
||||
|
||||
decryptedName, err := crypt.Decrypt(user.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Test User", decryptedName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateDuplicateDeletedUsers(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
require.NoError(t, migrateDuplicateDeletedUsers(context.Background(), db))
|
||||
assert.False(t, db.Migrator().HasTable("deleted_users"))
|
||||
|
||||
require.NoError(t, db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`).Error)
|
||||
assert.True(t, db.Migrator().HasTable("deleted_users"))
|
||||
|
||||
isPrimaryKey, err := isColumnPrimaryKey[activity.DeletedUser](db, "id")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isPrimaryKey)
|
||||
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", "email1", "name1", "GCM").Error)
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", "email2", "name2", "GCM").Error)
|
||||
require.NoError(t, migrateDuplicateDeletedUsers(context.Background(), db))
|
||||
|
||||
isPrimaryKey, err = isColumnPrimaryKey[activity.DeletedUser](db, "id")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isPrimaryKey)
|
||||
|
||||
var users []activity.DeletedUser
|
||||
require.NoError(t, db.Find(&users).Error)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "user1", users[0].ID)
|
||||
assert.Equal(t, "email2", users[0].Email)
|
||||
assert.Equal(t, "name2", users[0].Name)
|
||||
assert.Equal(t, "GCM", users[0].EncAlgo)
|
||||
}
|
||||
287
management/server/activity/store/sql_store.go
Normal file
287
management/server/activity/store/sql_store.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// eventSinkDB is the default name of the events database
|
||||
eventSinkDB = "events.db"
|
||||
|
||||
fallbackName = "unknown"
|
||||
fallbackEmail = "unknown@unknown.com"
|
||||
|
||||
gcmEncAlgo = "GCM"
|
||||
|
||||
storeEngineEnv = "NB_ACTIVITY_EVENT_STORE_ENGINE"
|
||||
postgresDsnEnv = "NB_ACTIVITY_EVENT_POSTGRES_DSN"
|
||||
sqlMaxOpenConnsEnv = "NB_SQL_MAX_OPEN_CONNS"
|
||||
)
|
||||
|
||||
type eventWithNames struct {
|
||||
activity.Event
|
||||
InitiatorName string
|
||||
InitiatorEmail string
|
||||
TargetName string
|
||||
TargetEmail string
|
||||
}
|
||||
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
fieldEncrypt *FieldEncrypt
|
||||
}
|
||||
|
||||
// NewSqlStore creates a new Store with an event table if not exists.
|
||||
func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
|
||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := initDatabase(ctx, dataDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initialize database: %w", err)
|
||||
}
|
||||
|
||||
if err = migrate(ctx, crypt, db); err != nil {
|
||||
return nil, fmt.Errorf("events database migration: %w", err)
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&activity.Event{}, &activity.DeletedUser{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("events auto migrate: %w", err)
|
||||
}
|
||||
|
||||
return &Store{
|
||||
db: db,
|
||||
fieldEncrypt: crypt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *Store) processResult(ctx context.Context, events []*eventWithNames) ([]*activity.Event, error) {
|
||||
activityEvents := make([]*activity.Event, 0)
|
||||
var cryptErr error
|
||||
|
||||
for _, event := range events {
|
||||
e := event.Event
|
||||
if e.Meta == nil {
|
||||
e.Meta = make(map[string]any)
|
||||
}
|
||||
|
||||
if event.TargetName != "" {
|
||||
name, err := store.fieldEncrypt.Decrypt(event.TargetName)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", event.TargetName)
|
||||
e.Meta["username"] = fallbackName
|
||||
} else {
|
||||
e.Meta["username"] = name
|
||||
}
|
||||
}
|
||||
|
||||
if event.TargetEmail != "" {
|
||||
email, err := store.fieldEncrypt.Decrypt(event.TargetEmail)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", event.TargetEmail)
|
||||
e.Meta["email"] = fallbackEmail
|
||||
} else {
|
||||
e.Meta["email"] = email
|
||||
}
|
||||
}
|
||||
|
||||
if event.InitiatorName != "" {
|
||||
name, err := store.fieldEncrypt.Decrypt(event.InitiatorName)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", event.InitiatorName)
|
||||
e.InitiatorName = fallbackName
|
||||
} else {
|
||||
e.InitiatorName = name
|
||||
}
|
||||
}
|
||||
|
||||
if event.InitiatorEmail != "" {
|
||||
email, err := store.fieldEncrypt.Decrypt(event.InitiatorEmail)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", event.InitiatorEmail)
|
||||
e.InitiatorEmail = fallbackEmail
|
||||
} else {
|
||||
e.InitiatorEmail = email
|
||||
}
|
||||
}
|
||||
|
||||
activityEvents = append(activityEvents, &e)
|
||||
}
|
||||
|
||||
if cryptErr != nil {
|
||||
log.WithContext(ctx).Warnf("%s", cryptErr)
|
||||
}
|
||||
|
||||
return activityEvents, nil
|
||||
}
|
||||
|
||||
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
|
||||
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
baseQuery := store.db.Model(&activity.Event{}).
|
||||
Select(`
|
||||
events.*,
|
||||
u.name AS initiator_name,
|
||||
u.email AS initiator_email,
|
||||
t.name AS target_name,
|
||||
t.email AS target_email
|
||||
`).
|
||||
Joins(`LEFT JOIN deleted_users u ON u.id = events.initiator_id`).
|
||||
Joins(`LEFT JOIN deleted_users t ON t.id = events.target_id`)
|
||||
|
||||
orderDir := "DESC"
|
||||
if !descending {
|
||||
orderDir = "ASC"
|
||||
}
|
||||
|
||||
var events []*eventWithNames
|
||||
err := baseQuery.Order("events.timestamp "+orderDir).Offset(offset).Limit(limit).
|
||||
Find(&events, "account_id = ?", accountID).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.processResult(ctx, events)
|
||||
}
|
||||
|
||||
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
|
||||
eventCopy := event.Copy()
|
||||
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(eventCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventCopy.Meta = meta
|
||||
|
||||
if err = store.db.Create(eventCopy).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return eventCopy, nil
|
||||
}
|
||||
|
||||
// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
|
||||
// this item from meta map
|
||||
func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
|
||||
email, ok := event.Meta["email"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
name, ok := event.Meta["name"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
deletedUser := activity.DeletedUser{
|
||||
ID: event.TargetID,
|
||||
EncAlgo: gcmEncAlgo,
|
||||
}
|
||||
|
||||
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deletedUser.Email = encryptedEmail
|
||||
|
||||
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deletedUser.Name = encryptedName
|
||||
|
||||
err = store.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"email", "name"}),
|
||||
}).Create(deletedUser).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(event.Meta) == 2 {
|
||||
return nil, nil // nolint
|
||||
}
|
||||
delete(event.Meta, "email")
|
||||
delete(event.Meta, "name")
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
// Close the Store
|
||||
func (store *Store) Close(_ context.Context) error {
|
||||
if store.db != nil {
|
||||
sql, err := store.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
|
||||
var dialector gorm.Dialector
|
||||
var storeEngine = types.SqliteStoreEngine
|
||||
|
||||
if engine, ok := os.LookupEnv(storeEngineEnv); ok {
|
||||
storeEngine = types.Engine(engine)
|
||||
}
|
||||
|
||||
switch storeEngine {
|
||||
case types.SqliteStoreEngine:
|
||||
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
|
||||
case types.PostgresStoreEngine:
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s environment variable not set", postgresDsnEnv)
|
||||
}
|
||||
dialector = postgres.Open(dsn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported store engine: %s", storeEngine)
|
||||
}
|
||||
log.WithContext(ctx).Infof("using %s as activity event store engine", storeEngine)
|
||||
|
||||
db, err := gorm.Open(dialector, &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open db connection: %w", err)
|
||||
}
|
||||
|
||||
return configureConnectionPool(db, storeEngine)
|
||||
}
|
||||
|
||||
func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, error) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
} else {
|
||||
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
|
||||
if err != nil {
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(conns)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package sqlite
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
func TestNewSQLiteStore(t *testing.T) {
|
||||
func TestNewSqlStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
key, _ := GenerateKey()
|
||||
store, err := NewSQLiteStore(context.Background(), dataDir, key)
|
||||
store, err := NewSqlStore(context.Background(), dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return userAuth, err
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||
return am.store.MarkPATUsed(ctx, tokenID)
|
||||
}
|
||||
|
||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package context
|
||||
|
||||
import "github.com/netbirdio/netbird/shared/context"
|
||||
|
||||
const (
|
||||
RequestIDKey = "requestID"
|
||||
AccountIDKey = "accountID"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
RequestIDKey = context.RequestIDKey
|
||||
AccountIDKey = context.AccountIDKey
|
||||
UserIDKey = context.UserIDKey
|
||||
PeerIDKey = context.PeerIDKey
|
||||
)
|
||||
|
||||
@@ -8,14 +8,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -216,8 +216,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -267,7 +269,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
|
||||
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
@@ -493,7 +495,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -504,7 +506,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -560,11 +562,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
|
||||
const (
|
||||
ephemeralLifeTime = 10 * time.Minute
|
||||
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
|
||||
cleanupWindow = 1 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,6 +43,9 @@ type EphemeralManager struct {
|
||||
tailPeer *ephemeralPeer
|
||||
peersLock sync.Mutex
|
||||
timer *time.Timer
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
@@ -48,6 +53,9 @@ func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *E
|
||||
return &EphemeralManager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
|
||||
lifeTime: ephemeralLifeTime,
|
||||
cleanupWindow: cleanupWindow,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +68,7 @@ func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
|
||||
|
||||
e.loadEphemeralPeers(ctx)
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
|
||||
e.timer = time.AfterFunc(e.lifeTime, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
@@ -113,22 +121,26 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
return
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare)
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
t := newDeadLine()
|
||||
t := e.newDeadLine()
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
@@ -155,7 +167,11 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
}
|
||||
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
} else {
|
||||
@@ -164,13 +180,20 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
bufferAccountCall := make(map[string]struct{})
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
} else {
|
||||
bufferAccountCall[p.accountID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for accountID := range bufferAccountCall {
|
||||
e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
@@ -223,6 +246,6 @@ func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func newDeadLine() time.Time {
|
||||
return timeNow().Add(ephemeralLifeTime)
|
||||
func (e *EphemeralManager) newDeadLine() time.Time {
|
||||
return timeNow().Add(e.lifeTime)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -27,28 +30,65 @@ func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStren
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
type MocAccountManager struct {
|
||||
type MockAccountManager struct {
|
||||
mu sync.Mutex
|
||||
nbAccount.Manager
|
||||
store *MockStore
|
||||
store *MockStore
|
||||
deletePeerCalls int
|
||||
bufferUpdateCalls map[string]int
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.deletePeerCalls++
|
||||
delete(a.store.account.Peers, peerID)
|
||||
return nil //nolint:nil
|
||||
if a.wg != nil {
|
||||
a.wg.Done()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocAccountManager) GetStore() store.Store {
|
||||
func (a *MockAccountManager) GetDeletePeerCalls() int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.deletePeerCalls
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
a.bufferUpdateCalls = make(map[string]int)
|
||||
}
|
||||
a.bufferUpdateCalls[accountID]++
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
return 0
|
||||
}
|
||||
return a.bufferUpdateCalls[accountID]
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetStore() store.Store {
|
||||
return a.store
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
am := MocAccountManager{
|
||||
am := MockAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
@@ -56,7 +96,7 @@ func TestNewManager(t *testing.T) {
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr := NewEphemeralManager(store, &am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
@@ -67,13 +107,16 @@ func TestNewManager(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewManagerPeerConnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
am := MocAccountManager{
|
||||
am := MockAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
@@ -81,7 +124,7 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr := NewEphemeralManager(store, &am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
@@ -95,13 +138,16 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
am := MocAccountManager{
|
||||
am := MockAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
@@ -109,7 +155,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr := NewEphemeralManager(store, &am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
@@ -126,8 +172,38 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
||||
const (
|
||||
ephemeralPeers = 10
|
||||
testLifeTime = 1 * time.Second
|
||||
testCleanupWindow = 100 * time.Millisecond
|
||||
)
|
||||
mockStore := &MockStore{}
|
||||
mockAM := &MockAccountManager{
|
||||
store: mockStore,
|
||||
}
|
||||
mockAM.wg = &sync.WaitGroup{}
|
||||
mockAM.wg.Add(ephemeralPeers)
|
||||
mgr := NewEphemeralManager(mockStore, mockAM)
|
||||
mgr.lifeTime = testLifeTime
|
||||
mgr.cleanupWindow = testCleanupWindow
|
||||
|
||||
account := newAccountWithId(context.Background(), "account", "", "", false)
|
||||
mockStore.account = account
|
||||
for i := range ephemeralPeers {
|
||||
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
time.Sleep(testCleanupWindow / ephemeralPeers)
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
}
|
||||
mockAM.wg.Wait()
|
||||
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
|
||||
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
|
||||
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func isEnabled() bool {
|
||||
@@ -66,7 +66,7 @@ func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, ta
|
||||
go func() {
|
||||
_, err := am.eventStore.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activityID,
|
||||
Activity: activityID.(activity.Activity),
|
||||
InitiatorID: initiatorID,
|
||||
TargetID: targetID,
|
||||
AccountID: accountID,
|
||||
@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -143,11 +143,10 @@ func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events [
|
||||
return eventUserInfos, nil
|
||||
}
|
||||
|
||||
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos, userId)
|
||||
return am.getEventsExternalUserInfo(ctx, externalUserIds, eventUserInfos)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo, userId string) (map[string]eventUserInfo, error) {
|
||||
externalAccountId := ""
|
||||
func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUserIds []string, eventUserInfos map[string]eventUserInfo) (map[string]eventUserInfo, error) {
|
||||
fetched := make(map[string]struct{})
|
||||
externalUsers := []*types.User{}
|
||||
for _, id := range externalUserIds {
|
||||
@@ -155,40 +154,36 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
||||
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||
if err != nil {
|
||||
// @todo consider logging
|
||||
continue
|
||||
}
|
||||
|
||||
if externalAccountId != "" && externalAccountId != externalUser.AccountID {
|
||||
return nil, fmt.Errorf("multiple external user accounts in events")
|
||||
}
|
||||
|
||||
if externalAccountId == "" {
|
||||
externalAccountId = externalUser.AccountID
|
||||
}
|
||||
|
||||
fetched[id] = struct{}{}
|
||||
externalUsers = append(externalUsers, externalUser)
|
||||
}
|
||||
|
||||
// if we couldn't determine an account, return what we have
|
||||
if externalAccountId == "" {
|
||||
log.WithContext(ctx).Warnf("failed to determine external user account from users: %v", externalUserIds)
|
||||
return eventUserInfos, nil
|
||||
usersByExternalAccount := map[string][]*types.User{}
|
||||
for _, u := range externalUsers {
|
||||
if _, ok := usersByExternalAccount[u.AccountID]; !ok {
|
||||
usersByExternalAccount[u.AccountID] = make([]*types.User, 0)
|
||||
}
|
||||
usersByExternalAccount[u.AccountID] = append(usersByExternalAccount[u.AccountID], u)
|
||||
}
|
||||
|
||||
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, userId, externalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for externalAccountId, externalUsers := range usersByExternalAccount {
|
||||
externalUserInfos, err := am.BuildUserInfosForAccount(ctx, externalAccountId, "", externalUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, k := range externalUserInfos {
|
||||
eventUserInfos[i] = eventUserInfo{
|
||||
email: k.Email,
|
||||
name: k.Name,
|
||||
accountId: externalAccountId,
|
||||
for i, k := range externalUserInfos {
|
||||
eventUserInfos[i] = eventUserInfo{
|
||||
email: k.Email,
|
||||
name: k.Name,
|
||||
accountId: externalAccountId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type GeoNames struct {
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type GroupLinkError struct {
|
||||
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -57,30 +57,152 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error {
|
||||
// CreateGroup object of the peers
|
||||
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create)
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveGroups adds new groups to the account.
|
||||
// UpdateGroup object of the peers
|
||||
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||
}
|
||||
|
||||
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.UpdateGroup(ctx, newGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error {
|
||||
operation := operations.Create
|
||||
if !create {
|
||||
operation = operations.Update
|
||||
}
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation)
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -112,11 +234,69 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave)
|
||||
return transaction.CreateGroups(ctx, accountID, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGroups updates groups in the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*types.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.UpdateGroups(ctx, accountID, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -140,7 +320,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
@@ -152,13 +332,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
}
|
||||
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
|
||||
return nil
|
||||
@@ -243,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -265,30 +445,20 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -325,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.UpdateGroup(ctx, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -347,30 +517,20 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -407,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.UpdateGroup(ctx, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -431,7 +591,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
return err
|
||||
@@ -448,7 +608,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
@@ -460,7 +620,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == types.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to get user")
|
||||
}
|
||||
@@ -506,7 +666,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to get DNS settings")
|
||||
}
|
||||
@@ -515,7 +675,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to get account settings")
|
||||
}
|
||||
@@ -529,7 +689,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -549,7 +709,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -567,7 +727,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -586,7 +746,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -602,7 +762,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -618,7 +778,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
|
||||
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -638,7 +798,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -664,18 +824,9 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -2,14 +2,20 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -18,10 +24,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,7 +48,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
}
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedIntegration
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
|
||||
group.ID = uuid.New().String()
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration)
|
||||
}
|
||||
@@ -48,7 +57,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedJWT
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
|
||||
group.ID = uuid.New().String()
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", types.GroupIssuedJWT)
|
||||
}
|
||||
@@ -56,7 +66,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedAPI
|
||||
group.ID = ""
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err == nil {
|
||||
t.Errorf("should not create api group with the same name, %s", group.Name)
|
||||
}
|
||||
@@ -162,7 +172,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true)
|
||||
err = manager.CreateGroups(context.Background(), account.Id, groupAdminUserID, groups)
|
||||
assert.NoError(t, err, "Failed to save test groups")
|
||||
|
||||
testCases := []struct {
|
||||
@@ -369,7 +379,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
@@ -382,13 +392,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
@@ -400,7 +410,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -426,8 +436,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupE",
|
||||
Peers: []string{peer2.ID},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -442,11 +455,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -513,7 +526,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
@@ -535,11 +548,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -604,11 +617,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -645,11 +658,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -672,11 +685,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -719,11 +732,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupE",
|
||||
Name: "GroupE",
|
||||
Peers: []string{peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -733,3 +746,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AddPeerToGroup(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
acc, err := createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToGroup(context.Background(), accountID, strconv.Itoa(i), acc.GroupsG[0].ID)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
peer := &peer2.Peer{
|
||||
ID: strconv.Itoa(i),
|
||||
AccountID: accountID,
|
||||
DNSLabel: "peer" + strconv.Itoa(i),
|
||||
IP: uint32ToIP(uint32(i)),
|
||||
}
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(context.Background(), peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func uint32ToIP(n uint32) net.IP {
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func Test_IncrementNetworkSerial(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.IncrementNetworkSerial(context.Background(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
|
||||
return nil, fmt.Errorf("error adding resource to group: %w", err)
|
||||
}
|
||||
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting group: %w", err)
|
||||
}
|
||||
|
||||
// TODO: at some point, this will need to become a switch statement
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID)
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting network resource: %w", err)
|
||||
}
|
||||
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
|
||||
return nil, fmt.Errorf("error removing resource from group: %w", err)
|
||||
}
|
||||
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting group: %w", err)
|
||||
}
|
||||
|
||||
// TODO: at some point, this will need to become a switch statement
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
|
||||
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting network resource: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,8 +20,10 @@ import (
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
@@ -29,9 +31,10 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -40,13 +43,14 @@ type GRPCServer struct {
|
||||
settingsManager settings.Manager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *types.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *types.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -60,6 +64,7 @@ func NewServer(
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager *EphemeralManager,
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
) (*GRPCServer, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -79,14 +84,15 @@ func NewServer(
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -392,6 +398,18 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
Cloud: meta.GetEnvironment().GetCloud(),
|
||||
Platform: meta.GetEnvironment().GetPlatform(),
|
||||
},
|
||||
Flags: nbpeer.Flags{
|
||||
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
|
||||
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
|
||||
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
|
||||
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
|
||||
DisableDNS: meta.GetFlags().GetDisableDNS(),
|
||||
DisableFirewall: meta.GetFlags().GetDisableFirewall(),
|
||||
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
|
||||
BlockInbound: meta.GetFlags().GetBlockInbound(),
|
||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
}
|
||||
@@ -517,7 +535,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), false),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
@@ -632,20 +650,21 @@ func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *T
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dnsResolutionOnRoutingPeerEnabled bool) *proto.PeerConfig {
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: dnsResolutionOnRoutingPeerEnabled,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
@@ -691,10 +710,11 @@ func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
@@ -730,7 +750,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra)
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
@@ -836,7 +856,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
|
||||
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.PKCEAuthorizationFlow{
|
||||
initInfoFlow := &proto.PKCEAuthorizationFlow{
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
||||
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
||||
@@ -847,9 +867,12 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
|
||||
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
|
||||
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
|
||||
LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag),
|
||||
},
|
||||
}
|
||||
|
||||
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||
@@ -888,6 +911,45 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||
start := time.Now()
|
||||
|
||||
empty := &proto.Empty{}
|
||||
peerKey, err := s.parseRequest(ctx, req, empty)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
|
||||
// TODO: consider idempotency
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.ID)
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, peer.AccountID)
|
||||
|
||||
userID := peer.UserID
|
||||
if userID == "" {
|
||||
userID = activity.SystemInitiator
|
||||
}
|
||||
|
||||
if err = s.accountManager.DeletePeer(ctx, peer.AccountID, peer.ID, userID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to logout peer %s: %v", peerKey.String(), err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// toProtocolChecks converts posture checks to protocol checks.
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
package: api
|
||||
generate:
|
||||
models: true
|
||||
embedded-spec: false
|
||||
output: types.gen.go
|
||||
compatibility:
|
||||
always-prefix-enum-values: true
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which realpath > /dev/null 2>&1
|
||||
then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
cd "$script_path"
|
||||
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@4a1477f6a8ba6ca8115cc23bb2fb67f0b9fca18e
|
||||
oapi-codegen --config cfg.yaml openapi.yml
|
||||
cd "$old_pwd"
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,21 +1,34 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// PeerBufferPercentage is the percentage of peers to add as buffer for network range calculations
|
||||
PeerBufferPercentage = 0.5
|
||||
// MinRequiredAddresses is the minimum number of addresses required in a network range
|
||||
MinRequiredAddresses = 10
|
||||
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
|
||||
MinNetworkBitsIPv4 = 28
|
||||
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
|
||||
MinNetworkBitsIPv6 = 120
|
||||
)
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
@@ -37,6 +50,86 @@ func newHandler(accountManager account.Manager, settingsManager settings.Manager
|
||||
}
|
||||
}
|
||||
|
||||
func validateIPAddress(addr netip.Addr) error {
|
||||
if addr.IsLoopback() {
|
||||
return status.Errorf(status.InvalidArgument, "loopback address range not allowed")
|
||||
}
|
||||
|
||||
if addr.IsMulticast() {
|
||||
return status.Errorf(status.InvalidArgument, "multicast address range not allowed")
|
||||
}
|
||||
|
||||
if addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast() {
|
||||
return status.Errorf(status.InvalidArgument, "link-local address range not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateMinimumSize(prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 {
|
||||
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4)
|
||||
}
|
||||
if addr.Is6() && prefix.Bits() > MinNetworkBitsIPv6 {
|
||||
return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error {
|
||||
if !networkRange.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := validateIPAddress(networkRange.Addr()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateMinimumSize(networkRange); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.validateCapacity(ctx, accountID, userID, networkRange)
|
||||
}
|
||||
|
||||
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
|
||||
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "get peer count: %v", err)
|
||||
}
|
||||
|
||||
maxHosts := calculateMaxHosts(prefix)
|
||||
requiredAddresses := calculateRequiredAddresses(len(peers))
|
||||
|
||||
if maxHosts < requiredAddresses {
|
||||
return status.Errorf(status.InvalidArgument,
|
||||
"network range too small: need at least %d addresses for %d peers + buffer, but range provides %d",
|
||||
requiredAddresses, len(peers), maxHosts)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateMaxHosts(prefix netip.Prefix) int64 {
|
||||
availableAddresses := prefix.Addr().BitLen() - prefix.Bits()
|
||||
maxHosts := int64(1) << availableAddresses
|
||||
|
||||
if prefix.Addr().Is4() {
|
||||
maxHosts -= 2 // network and broadcast addresses
|
||||
}
|
||||
|
||||
return maxHosts
|
||||
}
|
||||
|
||||
func calculateRequiredAddresses(peerCount int) int64 {
|
||||
requiredAddresses := int64(peerCount) + int64(float64(peerCount)*PeerBufferPercentage)
|
||||
if requiredAddresses < MinRequiredAddresses {
|
||||
requiredAddresses = MinRequiredAddresses
|
||||
}
|
||||
return requiredAddresses
|
||||
}
|
||||
|
||||
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
@@ -59,7 +152,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta)
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -122,8 +221,37 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Settings.DnsDomain != nil {
|
||||
settings.DNSDomain = *req.Settings.DnsDomain
|
||||
}
|
||||
if req.Settings.LazyConnectionEnabled != nil {
|
||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
}
|
||||
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
|
||||
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
|
||||
return
|
||||
}
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
settings.NetworkRange = prefix
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
var onboarding *types.AccountOnboarding
|
||||
if req.Onboarding != nil {
|
||||
onboarding = &types.AccountOnboarding{
|
||||
OnboardingFlowPending: req.Onboarding.OnboardingFlowPending,
|
||||
SignupFormPending: req.Onboarding.SignupFormPending,
|
||||
}
|
||||
}
|
||||
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -135,7 +263,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta)
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
@@ -164,7 +292,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account {
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
@@ -181,9 +309,20 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
JwtAllowGroups: &jwtAllowGroups,
|
||||
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
networkRangeStr := settings.NetworkRange.String()
|
||||
apiSettings.NetworkRange = &networkRangeStr
|
||||
}
|
||||
|
||||
apiOnboarding := api.AccountOnboarding{
|
||||
OnboardingFlowPending: onboarding.OnboardingFlowPending,
|
||||
SignupFormPending: onboarding.SignupFormPending,
|
||||
}
|
||||
|
||||
if settings.Extra != nil {
|
||||
apiSettings.Extra = &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||
@@ -199,5 +338,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
CreatedBy: meta.CreatedBy,
|
||||
Domain: meta.Domain,
|
||||
DomainCategory: meta.DomainCategory,
|
||||
Onboarding: apiOnboarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,10 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -46,9 +46,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
accCopy := account.Copy()
|
||||
accCopy.UpdateSettings(newSettings)
|
||||
return accCopy, nil
|
||||
return newSettings, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
return account.Copy(), nil
|
||||
@@ -56,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
|
||||
return account.GetMeta(), nil
|
||||
},
|
||||
GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
return &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}, nil
|
||||
},
|
||||
UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
return &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
settingsManager: settingsMockManager,
|
||||
}
|
||||
@@ -108,6 +118,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
JwtAllowGroups: &[]string{},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: true,
|
||||
@@ -118,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
@@ -129,6 +140,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
JwtAllowGroups: &[]string{},
|
||||
RegularUsersViewBlocked: false,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
@@ -139,6 +151,50 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr("roles"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{"test"},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK with JWT Propagation",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 554400,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
GroupsPropagationEnabled: br(true),
|
||||
JwtGroupsClaimName: sr("groups"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK without onboarding",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
@@ -150,27 +206,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
JwtAllowGroups: &[]string{"test"},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK with JWT Propagation",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 554400,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
GroupsPropagationEnabled: br(true),
|
||||
JwtGroupsClaimName: sr("groups"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
@@ -181,7 +217,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedArray: false,
|
||||
},
|
||||
@@ -190,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedArray: false,
|
||||
},
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// nameserversHandler is the nameserver group handler of the account
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// handler HTTP handler
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
IntegrationReference: existingGroup.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil {
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: types.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true)
|
||||
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -19,11 +19,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
)
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
)
|
||||
@@ -19,7 +19,8 @@ type routersHandler struct {
|
||||
|
||||
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager)
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
|
||||
@@ -41,6 +42,31 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routersResponse := make([]*api.NetworkRouter, 0)
|
||||
for _, routers := range routersMap {
|
||||
for _, router := range routers {
|
||||
routersResponse = append(routersResponse, router.ToAPIResponse())
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -13,10 +14,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -111,6 +112,19 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
}
|
||||
}
|
||||
|
||||
if req.Ip != nil {
|
||||
addr, err := netip.ParseAddr(*req.Ip)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
@@ -365,6 +379,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,11 +17,12 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -112,6 +114,15 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
p.Name = update.Name
|
||||
return p, nil
|
||||
},
|
||||
UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||
for _, peer := range peers {
|
||||
if peer.ID == peerID {
|
||||
peer.IP = net.IP(newIP.AsSlice())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("peer not found")
|
||||
},
|
||||
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
var p *nbpeer.Peer
|
||||
for _, peer := range peers {
|
||||
@@ -450,3 +461,73 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
testPeer := &nbpeer.Peer{
|
||||
ID: testPeerID,
|
||||
Key: "key",
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
Name: "test-host@netbird.io",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: regularUser,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host@netbird.io",
|
||||
Core: "22.04",
|
||||
},
|
||||
}
|
||||
|
||||
p := initTestMetaData(testPeer)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
peerID string
|
||||
requestBody string
|
||||
callerUserID string
|
||||
expectedStatus int
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "update peer IP successfully",
|
||||
peerID: testPeerID,
|
||||
requestBody: `{"ip": "100.64.0.100"}`,
|
||||
callerUserID: adminUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedIP: "100.64.0.100",
|
||||
},
|
||||
{
|
||||
name: "update peer IP with invalid IP",
|
||||
peerID: testPeerID,
|
||||
requestBody: `{"ip": "invalid-ip"}`,
|
||||
callerUserID: adminUser,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: tc.callerUserID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK && tc.expectedIP != "" {
|
||||
var updatedPeer api.Peer
|
||||
err := json.Unmarshal(rr.Body.Bytes(), &updatedPeer)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedIP, updatedPeer.Ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -255,23 +255,12 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
}
|
||||
|
||||
// validate policy object
|
||||
switch pr.Protocol {
|
||||
case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP:
|
||||
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
||||
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||
return
|
||||
}
|
||||
if !pr.Bidirectional {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP:
|
||||
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
policy.Rules = append(policy.Rules, &pr)
|
||||
}
|
||||
|
||||
@@ -435,9 +424,10 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
||||
}
|
||||
if group, ok := groupsMap[gid]; ok {
|
||||
minimum := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
ResourcesCount: len(group.Resources),
|
||||
}
|
||||
destinations = append(destinations, minimum)
|
||||
cache[gid] = minimum
|
||||
|
||||
@@ -14,9 +14,9 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// postureChecksHandler is a handler that returns posture checks of the account.
|
||||
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
var berlin = "Berlin"
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
@@ -133,12 +133,12 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
}
|
||||
|
||||
geoMock := &geolocation.Mock{}
|
||||
validatorMock := server.MocIntegratedValidator{}
|
||||
validatorMock := server.MockIntegratedValidator{}
|
||||
proxyController := integrations.NewController(store)
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// EmptyObject is an empty struct used to return empty JSON object
|
||||
type EmptyObject struct {
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP response in JSON format
|
||||
func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the duration
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the duration
|
||||
func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
switch value := v.(type) {
|
||||
case float64:
|
||||
d.Duration = time.Duration(value)
|
||||
return nil
|
||||
case string:
|
||||
var err error
|
||||
d.Duration, err = time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid duration")
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(httpStatus)
|
||||
err := json.NewEncoder(w).Encode(&ErrorResponse{
|
||||
Message: errMsg,
|
||||
Code: httpStatus,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteError converts an error to an JSON error response.
|
||||
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
||||
func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
log.WithContext(ctx).Errorf("got a handler error: %s", err.Error())
|
||||
errStatus, ok := status.FromError(err)
|
||||
httpStatus := http.StatusInternalServerError
|
||||
msg := "internal server error"
|
||||
if ok {
|
||||
switch errStatus.Type() {
|
||||
case status.UserAlreadyExists:
|
||||
httpStatus = http.StatusConflict
|
||||
case status.AlreadyExists:
|
||||
httpStatus = http.StatusConflict
|
||||
case status.PreconditionFailed:
|
||||
httpStatus = http.StatusPreconditionFailed
|
||||
case status.PermissionDenied:
|
||||
httpStatus = http.StatusForbidden
|
||||
case status.NotFound:
|
||||
httpStatus = http.StatusNotFound
|
||||
case status.Internal:
|
||||
httpStatus = http.StatusInternalServerError
|
||||
case status.InvalidArgument:
|
||||
httpStatus = http.StatusUnprocessableEntity
|
||||
case status.Unauthorized:
|
||||
httpStatus = http.StatusUnauthorized
|
||||
case status.BadRequest:
|
||||
httpStatus = http.StatusBadRequest
|
||||
default:
|
||||
}
|
||||
msg = strings.ToLower(err.Error())
|
||||
} else {
|
||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
|
||||
log.WithContext(ctx).Error(unhandledMSG)
|
||||
}
|
||||
|
||||
WriteErrorResponse(msg, httpStatus, w)
|
||||
}
|
||||
@@ -3,55 +3,71 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
|
||||
// UpdateIntegratedValidator updates the integrated validator groups for a specified account.
|
||||
// It retrieves the account associated with the provided userID, then updates the integrated validator groups
|
||||
// with the provided list of group ids. The updated account is then saved.
|
||||
//
|
||||
// Parameters:
|
||||
// - accountID: The ID of the account for which integrated validator groups are to be updated.
|
||||
// - userID: The ID of the user whose account is being updated.
|
||||
// - validator: The validator type to use, or empty to remove.
|
||||
// - groups: A slice of strings representing the ids of integrated validator groups to be updated.
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if any occurred during the process, otherwise returns nil
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
ok, err := am.GroupValidation(ctx, accountID, groups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
|
||||
return err
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error {
|
||||
if validator != "" && len(groups) == 0 {
|
||||
return fmt.Errorf("at least one group must be specified for validator")
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("invalid groups")
|
||||
return errors.New("invalid groups")
|
||||
if validator != "" {
|
||||
ok, err := am.GroupValidation(ctx, accountID, groups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("invalid groups")
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
} else {
|
||||
// ensure groups is empty
|
||||
groups = []string{}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
a, err := transaction.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var extra *types.ExtraSettings
|
||||
var extra *types.ExtraSettings
|
||||
|
||||
if a.Settings.Extra != nil {
|
||||
extra = a.Settings.Extra
|
||||
} else {
|
||||
extra = &types.ExtraSettings{}
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
if a.Settings.Extra != nil {
|
||||
extra = a.Settings.Extra
|
||||
} else {
|
||||
extra = &types.ExtraSettings{}
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
|
||||
extra.IntegratedValidator = validator
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return transaction.SaveAccount(ctx, a)
|
||||
})
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
@@ -61,7 +77,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
|
||||
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -81,43 +97,41 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
var peers []*nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
return err
|
||||
})
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra)
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||
}
|
||||
|
||||
type MocIntegratedValidator struct {
|
||||
type MockIntegratedValidator struct {
|
||||
integrated_validator.IntegratedValidator
|
||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
func (a MockIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
if a.ValidatePeerFunc != nil {
|
||||
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||
}
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for _, peer := range peers {
|
||||
validatedPeers[peer.ID] = struct{}{}
|
||||
@@ -125,22 +139,22 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*ty
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
|
||||
func (MockIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||
func (MockIntegratedValidator) PeerDeleted(_ context.Context, _, _ string, extraSettings *types.ExtraSettings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
||||
func (MockIntegratedValidator) SetPeerInvalidationListener(func(accountID string, peerIDs []string)) {
|
||||
// just a dummy
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
||||
func (MockIntegratedValidator) Stop(_ context.Context) {
|
||||
// just a dummy
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package integrated_validator
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
@@ -13,8 +14,9 @@ type IntegratedValidator interface {
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string) error
|
||||
SetPeerInvalidationListener(fn func(accountID string))
|
||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
|
||||
Stop(ctx context.Context)
|
||||
ValidateFlowResponse(ctx context.Context, peerKey string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -30,6 +29,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -440,11 +440,15 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
AnyTimes().
|
||||
Return(&types.Settings{}, nil)
|
||||
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
if err != nil {
|
||||
cleanup()
|
||||
@@ -454,7 +458,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
@@ -641,7 +645,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
@@ -206,12 +206,12 @@ func startServer(
|
||||
eventStore,
|
||||
nil,
|
||||
false,
|
||||
server.MocIntegratedValidator{},
|
||||
server.MockIntegratedValidator{},
|
||||
metrics,
|
||||
port_forwarding.NewControllerMock(),
|
||||
settingsMockManager,
|
||||
permissionsManager,
|
||||
)
|
||||
false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
}
|
||||
@@ -227,6 +227,7 @@ func startServer(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
server.MockIntegratedValidator{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
||||
@@ -184,7 +184,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
ephemeralPeersSKs int
|
||||
ephemeralPeersSKUsage int
|
||||
activePeersLastDay int
|
||||
activeUserPeersLastDay int
|
||||
osPeers map[string]int
|
||||
activeUsersLastDay map[string]struct{}
|
||||
userPeers int
|
||||
rules int
|
||||
rulesProtocol map[string]int
|
||||
@@ -203,6 +205,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
version string
|
||||
peerActiveVersions []string
|
||||
osUIClients map[string]int
|
||||
rosenpassEnabled int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -210,6 +213,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
osUIClients = make(map[string]int)
|
||||
rulesProtocol = make(map[string]int)
|
||||
rulesDirection = make(map[string]int)
|
||||
activeUsersLastDay = make(map[string]struct{})
|
||||
uptime = time.Since(w.startupTime).Seconds()
|
||||
connections := w.connManager.GetAllConnectedPeers()
|
||||
version = nbversion.NetbirdVersion()
|
||||
@@ -277,10 +281,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
for _, peer := range account.Peers {
|
||||
peers++
|
||||
|
||||
if peer.SSHEnabled {
|
||||
if peer.SSHEnabled || peer.Meta.Flags.ServerSSHAllowed {
|
||||
peersSSHEnabled++
|
||||
}
|
||||
|
||||
if peer.Meta.Flags.RosenpassEnabled {
|
||||
rosenpassEnabled++
|
||||
}
|
||||
|
||||
if peer.UserID != "" {
|
||||
userPeers++
|
||||
}
|
||||
@@ -299,6 +307,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
_, connected := connections[peer.ID]
|
||||
if connected || peer.Status.LastSeen.After(w.lastRun) {
|
||||
activePeersLastDay++
|
||||
if peer.UserID != "" {
|
||||
activeUserPeersLastDay++
|
||||
activeUsersLastDay[peer.UserID] = struct{}{}
|
||||
}
|
||||
osActiveKey := osKey + "_active"
|
||||
osActiveCount := osPeers[osActiveKey]
|
||||
osPeers[osActiveKey] = osActiveCount + 1
|
||||
@@ -320,6 +332,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs
|
||||
metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage
|
||||
metricsProperties["active_peers_last_day"] = activePeersLastDay
|
||||
metricsProperties["active_user_peers_last_day"] = activeUserPeersLastDay
|
||||
metricsProperties["active_users_last_day"] = len(activeUsersLastDay)
|
||||
metricsProperties["user_peers"] = userPeers
|
||||
metricsProperties["rules"] = rules
|
||||
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
|
||||
@@ -338,6 +352,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["ui_clients"] = uiClient
|
||||
metricsProperties["idp_manager"] = w.idpManager
|
||||
metricsProperties["store_engine"] = w.dataSource.GetStoreEngine()
|
||||
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
||||
|
||||
for protocol, count := range rulesProtocol {
|
||||
metricsProperties["rules_protocol_"+protocol] = count
|
||||
|
||||
@@ -47,8 +47,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
"1": {
|
||||
ID: "1",
|
||||
UserID: "test",
|
||||
SSHEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
|
||||
SSHEnabled: false,
|
||||
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1", Flags: nbpeer.Flags{ServerSSHAllowed: true, RosenpassEnabled: true}},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -312,7 +312,19 @@ func TestGenerateProperties(t *testing.T) {
|
||||
}
|
||||
|
||||
if properties["posture_checks"] != 2 {
|
||||
t.Errorf("expected 1 posture_checks, got %d", properties["posture_checks"])
|
||||
t.Errorf("expected 2 posture_checks, got %d", properties["posture_checks"])
|
||||
}
|
||||
|
||||
if properties["rosenpass_enabled"] != 1 {
|
||||
t.Errorf("expected 1 rosenpass_enabled, got %d", properties["rosenpass_enabled"])
|
||||
}
|
||||
|
||||
if properties["active_user_peers_last_day"] != 2 {
|
||||
t.Errorf("expected 2 active_user_peers_last_day, got %d", properties["active_user_peers_last_day"])
|
||||
}
|
||||
|
||||
if properties["active_users_last_day"] != 1 {
|
||||
t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"])
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func GetColumnName(db *gorm.DB, column string) string {
|
||||
@@ -39,6 +40,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasColumn(&model, fieldName) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
@@ -283,7 +289,7 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil {
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
|
||||
}
|
||||
|
||||
@@ -373,3 +379,111 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
|
||||
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("failed to parse model schema: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
dialect := db.Dialector.Name()
|
||||
|
||||
if db.Migrator().HasIndex(&model, indexName) {
|
||||
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
var columnClause string
|
||||
if dialect == "mysql" {
|
||||
var withLength []string
|
||||
for _, col := range columns {
|
||||
if col == "ip" || col == "dns_label" {
|
||||
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
|
||||
} else {
|
||||
withLength = append(withLength, col)
|
||||
}
|
||||
}
|
||||
columnClause = strings.Join(withLength, ", ")
|
||||
} else {
|
||||
columnClause = strings.Join(columns, ", ")
|
||||
}
|
||||
|
||||
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
|
||||
if dialect == "postgres" || dialect == "sqlite" {
|
||||
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
|
||||
if err := db.Exec(createStmt).Error; err != nil {
|
||||
return fmt.Errorf("failed to create index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error {
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if !db.Migrator().HasColumn(&model, columnName) {
|
||||
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
jsonValue, ok := row[columnName].(string)
|
||||
if !ok || jsonValue == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var data []string
|
||||
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
|
||||
return fmt.Errorf("unmarshal json: %w", err)
|
||||
}
|
||||
|
||||
for _, value := range data {
|
||||
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(
|
||||
mapperFunc(row["account_id"].(string), row["id"].(string), value),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
|
||||
return fmt.Errorf("drop column %s: %w", columnName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,16 +4,21 @@ import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -21,7 +26,41 @@ import (
|
||||
func setupDatabase(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
var dsn string
|
||||
var cleanup func()
|
||||
switch os.Getenv("NETBIRD_STORE_ENGINE") {
|
||||
case "mysql":
|
||||
cleanup, dsn, err = testutil.CreateMysqlTestContainer()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create MySQL test container: %v", err)
|
||||
}
|
||||
|
||||
if dsn == "" {
|
||||
t.Fatal("MySQL connection string is empty, ensure the test container is running")
|
||||
}
|
||||
|
||||
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
||||
case "postgres":
|
||||
cleanup, dsn, err = testutil.CreatePostgresTestContainer()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PostgreSQL test container: %v", err)
|
||||
}
|
||||
|
||||
if dsn == "" {
|
||||
t.Fatalf("PostgreSQL connection string is empty, ensure the test container is running")
|
||||
}
|
||||
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
case "sqlite":
|
||||
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
default:
|
||||
db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
}
|
||||
if cleanup != nil {
|
||||
t.Cleanup(cleanup)
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Failed to open database")
|
||||
return db
|
||||
@@ -34,6 +73,7 @@ func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&types.Account{}, &route.Route{})
|
||||
@@ -97,6 +137,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
||||
@@ -117,12 +158,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||
}
|
||||
|
||||
err = db.Save(&account{
|
||||
a := &account{
|
||||
Account: types.Account{Id: "123"},
|
||||
Peers: []peer{
|
||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
).Error
|
||||
}
|
||||
|
||||
err = db.Save(a).Error
|
||||
require.NoError(t, err, "Failed to insert account")
|
||||
|
||||
a.Peers = []peer{
|
||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}
|
||||
|
||||
err = db.Save(a).Error
|
||||
require.NoError(t, err, "Failed to insert blob data")
|
||||
|
||||
var blobValue string
|
||||
@@ -143,12 +190,18 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&types.Account{
|
||||
account := &types.Account{
|
||||
Id: "1234",
|
||||
PeersG: []nbpeer.Peer{
|
||||
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
).Error
|
||||
}
|
||||
|
||||
err = db.Save(account).Error
|
||||
require.NoError(t, err, "Failed to insert account")
|
||||
|
||||
account.PeersG = []nbpeer.Peer{
|
||||
{AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}
|
||||
|
||||
err = db.Save(account).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
|
||||
@@ -162,12 +215,13 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&types.SetupKey{})
|
||||
err := db.AutoMigrate(&types.SetupKey{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
||||
Id: "1",
|
||||
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
||||
UpdatedAt: time.Now(),
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
@@ -192,6 +246,7 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
KeySecret: "EEFDA****",
|
||||
UpdatedAt: time.Now(),
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
@@ -213,8 +268,9 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
UpdatedAt: time.Now(),
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
@@ -235,8 +291,9 @@ func TestDropIndex(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
UpdatedAt: time.Now(),
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
@@ -249,3 +306,37 @@ func TestDropIndex(t *testing.T) {
|
||||
exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id")
|
||||
assert.False(t, exist, "Should not have the index")
|
||||
}
|
||||
|
||||
func TestCreateIndex(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := db.AutoMigrate(&nbpeer.Peer{})
|
||||
assert.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
indexName := "idx_account_ip"
|
||||
|
||||
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
|
||||
assert.NoError(t, err, "Migration should not fail to create index")
|
||||
|
||||
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||
assert.True(t, exist, "Should have the index")
|
||||
}
|
||||
|
||||
func TestCreateIndexIfExists(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := db.AutoMigrate(&nbpeer.Peer{})
|
||||
assert.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
indexName := "idx_account_ip"
|
||||
|
||||
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
|
||||
assert.NoError(t, err, "Migration should not fail to create index")
|
||||
|
||||
exist := db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||
assert.True(t, exist, "Should have the index")
|
||||
|
||||
err = migration.CreateIndexIfNotExists[nbpeer.Peer](context.Background(), db, indexName, "account_id", "ip")
|
||||
assert.NoError(t, err, "Create index should not fail if index exists")
|
||||
|
||||
exist = db.Migrator().HasIndex(&nbpeer.Peer{}, indexName)
|
||||
assert.True(t, exist, "Should have the index")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -30,98 +30,139 @@ type MockAccountManager struct {
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func(settings *types.Settings) string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() account.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func(settings *types.Settings) string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() account.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
if am.SaveGroupFunc != nil {
|
||||
return am.SaveGroupFunc(ctx, accountID, userID, group, true)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CreateGroup is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
if am.SaveGroupFunc != nil {
|
||||
return am.SaveGroupFunc(ctx, accountID, userID, group, false)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateGroup is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
|
||||
if am.SaveGroupsFunc != nil {
|
||||
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, true)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CreateGroups is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
|
||||
if am.SaveGroupsFunc != nil {
|
||||
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, false)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
// do nothing
|
||||
if am.UpdateAccountPeersFunc != nil {
|
||||
am.UpdateAccountPeersFunc(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
if am.BufferUpdateAccountPeersFunc != nil {
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
@@ -443,6 +484,13 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||
if am.UpdatePeerIPFunc != nil {
|
||||
return am.UpdatePeerIPFunc(ctx, accountID, userID, peerID, newIP)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented")
|
||||
}
|
||||
|
||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
if am.CreateRouteFunc != nil {
|
||||
@@ -661,7 +709,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
|
||||
}
|
||||
@@ -757,10 +805,10 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
if am.UpdateIntegratedValidatorGroupsFunc != nil {
|
||||
return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups)
|
||||
// UpdateIntegratedValidator mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error {
|
||||
if am.UpdateIntegratedValidatorFunc != nil {
|
||||
return am.UpdateIntegratedValidatorFunc(ctx, accountID, userID, validator, groups)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented")
|
||||
}
|
||||
@@ -813,6 +861,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
|
||||
if am.GetAccountOnboardingFunc != nil {
|
||||
return am.GetAccountOnboardingFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented")
|
||||
}
|
||||
|
||||
// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
if am.UpdateAccountOnboardingFunc != nil {
|
||||
return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented")
|
||||
}
|
||||
|
||||
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
if am.GetUserByIDFunc != nil {
|
||||
@@ -862,11 +926,11 @@ func (am *MockAccountManager) GetStore() store.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
|
||||
if am.CreateAccountByPrivateDomainFunc != nil {
|
||||
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
||||
func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
|
||||
if am.GetOrCreateAccountByPrivateDomainFunc != nil {
|
||||
return am.GetOrCreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
|
||||
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type ManagementServiceServerMock struct {
|
||||
|
||||
@@ -13,12 +13,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
|
||||
|
||||
var invalidDomainName = errors.New("invalid domain name")
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
@@ -30,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
|
||||
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
@@ -71,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup)
|
||||
return transaction.SaveNameServerGroup(ctx, newNSGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -110,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave)
|
||||
return transaction.SaveNameServerGroup(ctx, nsGroupToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -171,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
||||
return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -200,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
@@ -214,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return err
|
||||
}
|
||||
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -224,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -319,13 +321,9 @@ func validateDomain(domain string) error {
|
||||
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
|
||||
}
|
||||
|
||||
labels, valid := dns.IsDomainName(domain)
|
||||
_, valid := dns.IsDomainName(domain)
|
||||
if !valid {
|
||||
return errors.New("invalid domain name")
|
||||
}
|
||||
|
||||
if labels < 2 {
|
||||
return errors.New("domain should consists of a minimum of two labels")
|
||||
return invalidDomainName
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -778,8 +778,14 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -848,7 +854,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
@@ -899,13 +905,33 @@ func TestValidateDomain(t *testing.T) {
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with double hyphen",
|
||||
domain: "test--example.com",
|
||||
name: "Valid domain name with only one label",
|
||||
domain: "example",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Valid domain name with trailing dot",
|
||||
domain: "example.",
|
||||
errFunc: require.NoError,
|
||||
},
|
||||
{
|
||||
name: "Invalid wildcard domain name",
|
||||
domain: "*.example",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with only one label",
|
||||
domain: "com",
|
||||
name: "Invalid domain name with leading dot",
|
||||
domain: ".com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with dot only",
|
||||
domain: ".",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain name with double hyphen",
|
||||
domain: "test--example.com",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
@@ -954,18 +980,18 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
var newNameServerGroupA *nbdns.NameServerGroup
|
||||
var newNameServerGroupB *nbdns.NameServerGroup
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
}, true)
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID)
|
||||
return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
@@ -73,7 +73,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
|
||||
defer unlock()
|
||||
|
||||
err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
|
||||
err = m.store.SaveNetwork(ctx, network)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
@@ -114,7 +114,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||
|
||||
return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
|
||||
return network, m.store.SaveNetwork(ctx, network)
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
@@ -162,12 +162,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
err = transaction.DeleteNetwork(ctx, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete network: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
|
||||
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
|
||||
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network resources: %w", err)
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource)
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
@@ -218,22 +218,22 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
|
||||
}
|
||||
|
||||
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
|
||||
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
|
||||
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil && oldResource.ID != resource.ID {
|
||||
return status.Errorf(status.InvalidArgument, "new resource name already exists")
|
||||
}
|
||||
|
||||
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
|
||||
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource)
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
@@ -248,7 +248,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -325,7 +325,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -375,7 +375,7 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID)
|
||||
err = transaction.DeleteNetworkResource(ctx, accountID, resourceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete network resource: %w", err)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
nbDomain "github.com/netbirdio/netbird/management/domain"
|
||||
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type NetworkResourceType string
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
|
||||
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
|
||||
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network routers: %w", err)
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
var network *networkTypes.Network
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
@@ -104,12 +104,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router)
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID)
|
||||
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network router: %w", err)
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
var network *networkTypes.Network
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
@@ -171,12 +171,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router)
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
return fmt.Errorf("failed to delete network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
@@ -246,7 +246,7 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
|
||||
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
|
||||
}
|
||||
|
||||
err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID)
|
||||
err = transaction.DeleteNetworkRouter(ctx, accountID, routerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete network router: %w", err)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package types
|
||||
import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,14 +20,14 @@ type Peer struct {
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
DNSLabel string // uniqueness index per accountID (check migrations)
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
// The user ID that registered the peer
|
||||
@@ -94,6 +94,22 @@ type File struct {
|
||||
ProcessIsRunning bool
|
||||
}
|
||||
|
||||
// Flags defines a set of options to control feature behavior
|
||||
type Flags struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
// PeerSystemMeta is a metadata of a Peer machine system
|
||||
type PeerSystemMeta struct { //nolint:revive
|
||||
Hostname string
|
||||
@@ -111,6 +127,7 @@ type PeerSystemMeta struct { //nolint:revive
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment `gorm:"serializer:json"`
|
||||
Flags Flags `gorm:"serializer:json"`
|
||||
Files []File `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
@@ -155,7 +172,8 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||
p.SystemProductName == other.SystemProductName &&
|
||||
p.SystemManufacturer == other.SystemManufacturer &&
|
||||
p.Environment.Cloud == other.Environment.Cloud &&
|
||||
p.Environment.Platform == other.Environment.Platform
|
||||
p.Environment.Platform == other.Environment.Platform &&
|
||||
p.Flags.isEqual(other.Flags)
|
||||
}
|
||||
|
||||
func (p PeerSystemMeta) isEmpty() bool {
|
||||
@@ -315,3 +333,16 @@ func (p *Peer) UpdateLastLogin() *Peer {
|
||||
p.Status = newStatus
|
||||
return p
|
||||
}
|
||||
|
||||
func (f Flags) isEqual(other Flags) bool {
|
||||
return f.RosenpassEnabled == other.RosenpassEnabled &&
|
||||
f.RosenpassPermissive == other.RosenpassPermissive &&
|
||||
f.ServerSSHAllowed == other.ServerSSHAllowed &&
|
||||
f.DisableClientRoutes == other.DisableClientRoutes &&
|
||||
f.DisableServerRoutes == other.DisableServerRoutes &&
|
||||
f.DisableDNS == other.DisableDNS &&
|
||||
f.DisableFirewall == other.DisableFirewall &&
|
||||
f.BlockLANAccess == other.BlockLANAccess &&
|
||||
f.BlockInbound == other.BlockInbound &&
|
||||
f.LazyConnectionEnabled == other.LazyConnectionEnabled
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// FQDNOld is the original implementation for benchmarking purposes
|
||||
@@ -83,3 +85,59 @@ func TestIsEqual(t *testing.T) {
|
||||
t.Error("meta1 should be equal to meta2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlags_IsEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
f1 Flags
|
||||
f2 Flags
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "should be equal when all fields are identical",
|
||||
f1: Flags{
|
||||
RosenpassEnabled: true, RosenpassPermissive: false, ServerSSHAllowed: true,
|
||||
DisableClientRoutes: false, DisableServerRoutes: true, DisableDNS: false,
|
||||
DisableFirewall: true, BlockLANAccess: false, BlockInbound: true, LazyConnectionEnabled: true,
|
||||
},
|
||||
f2: Flags{
|
||||
RosenpassEnabled: true, RosenpassPermissive: false, ServerSSHAllowed: true,
|
||||
DisableClientRoutes: false, DisableServerRoutes: true, DisableDNS: false,
|
||||
DisableFirewall: true, BlockLANAccess: false, BlockInbound: true, LazyConnectionEnabled: true,
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "shouldn't be equal when fields are different",
|
||||
f1: Flags{
|
||||
RosenpassEnabled: true, RosenpassPermissive: false, ServerSSHAllowed: true,
|
||||
DisableClientRoutes: false, DisableServerRoutes: true, DisableDNS: false,
|
||||
DisableFirewall: true, BlockLANAccess: false, BlockInbound: true, LazyConnectionEnabled: true,
|
||||
},
|
||||
f2: Flags{
|
||||
RosenpassEnabled: false, RosenpassPermissive: true, ServerSSHAllowed: false,
|
||||
DisableClientRoutes: true, DisableServerRoutes: false, DisableDNS: true,
|
||||
DisableFirewall: false, BlockLANAccess: true, BlockInbound: false, LazyConnectionEnabled: false,
|
||||
},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "should be equal when both are empty",
|
||||
f1: Flags{},
|
||||
f2: Flags{},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "shouldn't be equal when at least one field differs",
|
||||
f1: Flags{RosenpassEnabled: true},
|
||||
f2: Flags{RosenpassEnabled: false},
|
||||
expect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expect, tt.f1.isEqual(tt.f2))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,10 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -18,11 +22,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
@@ -31,8 +38,6 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -40,6 +45,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestPeer_LoginExpired(t *testing.T) {
|
||||
@@ -303,12 +310,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
group1.Peers = append(group1.Peers, peer1.ID)
|
||||
group2.Peers = append(group2.Peers, peer2.ID)
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true)
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &group1)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group1 to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true)
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &group2)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group2 to be added, got failure %v", err)
|
||||
return
|
||||
@@ -479,7 +486,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -666,7 +673,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@@ -736,7 +743,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1156,8 +1163,8 @@ func TestToSyncResponse(t *testing.T) {
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, true, nil)
|
||||
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil)
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
@@ -1266,7 +1273,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1290,15 +1297,21 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
LastLogin: util.ToPtr(time.Now()),
|
||||
ExtraDNSLabels: []string{
|
||||
"extraLabel1",
|
||||
"extraLabel2",
|
||||
},
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
|
||||
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.UserID, existingUserID)
|
||||
assert.Equal(t, newPeer.ExtraDNSLabels, peer.ExtraDNSLabels)
|
||||
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
@@ -1333,21 +1346,23 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
_, err = s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
newPeerTemplate := &nbpeer.Peer{
|
||||
AccountID: existingAccountID,
|
||||
Key: "newPeerKey",
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
@@ -1358,35 +1373,113 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
DNSLabel: "newPeer.test",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
ExtraDNSLabels: []string{
|
||||
"extraLabel1",
|
||||
"extraLabel2",
|
||||
},
|
||||
}
|
||||
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
|
||||
testCases := []struct {
|
||||
name string
|
||||
existingSetupKeyID string
|
||||
expectedGroupIDsInAccount []string
|
||||
expectAddPeerError bool
|
||||
errorType status.Type
|
||||
expectedErrorMsgSubstring string
|
||||
}{
|
||||
{
|
||||
name: "Successful registration with setup key allowing extra DNS labels",
|
||||
existingSetupKeyID: "A2C8E62B-38F5-4553-B31E-DD66C696CEBD",
|
||||
expectAddPeerError: false,
|
||||
expectedGroupIDsInAccount: []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g4g"},
|
||||
},
|
||||
{
|
||||
name: "Failed registration with setup key not allowing extra DNS labels",
|
||||
existingSetupKeyID: "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
expectAddPeerError: true,
|
||||
errorType: status.PreconditionFailed,
|
||||
expectedErrorMsgSubstring: "setup key doesn't allow extra DNS labels",
|
||||
},
|
||||
{
|
||||
name: "Absent setup key",
|
||||
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
|
||||
expectAddPeerError: true,
|
||||
errorType: status.NotFound,
|
||||
expectedErrorMsgSubstring: "couldn't add peer: setup key is invalid",
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: newPeerTemplate.AccountID,
|
||||
Key: "newPeerKey_" + xid.New().String(),
|
||||
UserID: newPeerTemplate.UserID,
|
||||
IP: newPeerTemplate.IP,
|
||||
Meta: newPeerTemplate.Meta,
|
||||
Name: newPeerTemplate.Name,
|
||||
DNSLabel: newPeerTemplate.DNSLabel,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: newPeerTemplate.SSHEnabled,
|
||||
ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels,
|
||||
}
|
||||
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), tc.existingSetupKeyID, "", currentPeer)
|
||||
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
|
||||
if tc.expectAddPeerError {
|
||||
require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID)
|
||||
assert.Contains(t, err.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch")
|
||||
e, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatal("Failed to map error")
|
||||
}
|
||||
assert.Equal(t, e.Type(), tc.errorType)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(1), account.Network.Serial)
|
||||
require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.existingSetupKeyID)
|
||||
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
|
||||
|
||||
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
|
||||
assert.NoError(t, err)
|
||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
|
||||
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
|
||||
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
|
||||
assert.Equal(t, addedPeer.ID, peerFromStore.ID, "Peer ID mismatch between addedPeer and peerFromStore")
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(existingSetupKeyID))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
assert.NotEqual(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed)
|
||||
assert.Equal(t, 1, account.SetupKeys[encodedHashedKey].UsedTimes)
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err, "Failed to get account: %s", existingAccountID)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID, "Peer ID not found in account.Peers")
|
||||
|
||||
for _, groupID := range tc.expectedGroupIDsInAccount {
|
||||
require.NotNil(t, account.Groups[groupID], "Group %s not found in account", groupID)
|
||||
assert.Contains(t, account.Groups[groupID].Peers, addedPeer.ID, "Peer ID %s not found in group %s", addedPeer.ID, groupID)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(1), account.Network.Serial, "Network.Serial mismatch; this assumes specific initial state or increment logic.")
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(tc.existingSetupKeyID))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
setupKeyData, ok := account.SetupKeys[encodedHashedKey]
|
||||
require.True(t, ok, "Setup key data not found in account.SetupKeys for key ID %s (encoded: %s)", tc.existingSetupKeyID, encodedHashedKey)
|
||||
|
||||
var zeroTime time.Time
|
||||
assert.NotEqual(t, zeroTime, setupKeyData.LastUsed, "Setup key LastUsed time should have been updated and not be zero.")
|
||||
|
||||
assert.Equal(t, 1, setupKeyData.UsedTimes, "Setup key UsedTimes should be 1 after first use.")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine == "sqlite" || engine == "mysql" || engine == "" {
|
||||
// we intentionally disabled foreign keys in mysql
|
||||
t.Skip("Skipping test because store is not respecting foreign keys")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1408,7 +1501,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1436,7 +1529,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
|
||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
|
||||
require.Error(t, err)
|
||||
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
@@ -1456,13 +1549,172 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes)
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
_, err = s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err, "Failed to get existing account, check testdata/extended-store.sql. Account ID: %s", existingAccountID)
|
||||
|
||||
baseMeta := nbpeer.PeerSystemMeta{
|
||||
Hostname: "loginPeerHost",
|
||||
GoOS: "linux",
|
||||
}
|
||||
|
||||
newPeerTemplate := &nbpeer.Peer{
|
||||
AccountID: existingAccountID,
|
||||
UserID: "",
|
||||
IP: net.IP{123, 123, 123, 123},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "newPeer",
|
||||
GoOS: "linux",
|
||||
},
|
||||
Name: "newPeerName",
|
||||
DNSLabel: "newPeer.test",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
ExtraDNSLabels: []string{
|
||||
"extraLabel1",
|
||||
"extraLabel2",
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupKey string
|
||||
wireGuardPubKey string
|
||||
expectExtraDNSLabelsMismatch bool
|
||||
extraDNSLabels []string
|
||||
expectLoginError bool
|
||||
expectedErrorMsgSubstring string
|
||||
}{
|
||||
{
|
||||
name: "Successful login with setup key",
|
||||
setupKey: "A2C8E62B-38F5-4553-B31E-DD66C696CEBD",
|
||||
expectLoginError: false,
|
||||
},
|
||||
{
|
||||
name: "Successful login with setup key with DNS labels mismatch",
|
||||
setupKey: "A2C8E62B-38F5-4553-B31E-DD66C696CEBD",
|
||||
expectExtraDNSLabelsMismatch: true,
|
||||
extraDNSLabels: []string{"anotherLabel1", "anotherLabel2"},
|
||||
expectLoginError: false,
|
||||
},
|
||||
{
|
||||
name: "Failed login with setup key not allowing extra DNS labels",
|
||||
setupKey: "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
|
||||
expectExtraDNSLabelsMismatch: true,
|
||||
extraDNSLabels: []string{"anotherLabel1", "anotherLabel2"},
|
||||
expectLoginError: true,
|
||||
expectedErrorMsgSubstring: "setup key doesn't allow extra DNS labels",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
currentWireGuardPubKey := "testPubKey_" + xid.New().String()
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
upperKey := strings.ToUpper(tc.setupKey)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
sk, err := s.GetSetupKeyBySecret(context.Background(), store.LockingStrengthUpdate, encodedHashedKey)
|
||||
require.NoError(t, err, "Failed to get setup key %s from storage", tc.setupKey)
|
||||
|
||||
currentPeer := &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: newPeerTemplate.AccountID,
|
||||
Key: currentWireGuardPubKey,
|
||||
UserID: newPeerTemplate.UserID,
|
||||
IP: newPeerTemplate.IP,
|
||||
Meta: newPeerTemplate.Meta,
|
||||
Name: newPeerTemplate.Name,
|
||||
DNSLabel: newPeerTemplate.DNSLabel,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: newPeerTemplate.SSHEnabled,
|
||||
}
|
||||
// add peer manually to bypass creation during login stage
|
||||
if sk.AllowExtraDNSLabels {
|
||||
currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), tc.setupKey, "", currentPeer)
|
||||
require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey)
|
||||
|
||||
loginInput := types.PeerLogin{
|
||||
WireGuardPubKey: currentWireGuardPubKey,
|
||||
SSHKey: "test-ssh-key",
|
||||
Meta: baseMeta,
|
||||
UserID: "",
|
||||
SetupKey: tc.setupKey,
|
||||
ConnectionIP: net.ParseIP("192.0.2.100"),
|
||||
}
|
||||
|
||||
if tc.expectExtraDNSLabelsMismatch {
|
||||
loginInput.ExtraDNSLabels = tc.extraDNSLabels
|
||||
}
|
||||
|
||||
loggedinPeer, networkMap, postureChecks, loginErr := am.LoginPeer(context.Background(), loginInput)
|
||||
if tc.expectLoginError {
|
||||
require.Error(t, loginErr, "Expected an error during LoginPeer with setup key: %s", tc.setupKey)
|
||||
assert.Contains(t, loginErr.Error(), tc.expectedErrorMsgSubstring, "Error message mismatch")
|
||||
assert.Nil(t, loggedinPeer, "LoggedinPeer should be nil on error")
|
||||
assert.Nil(t, networkMap, "NetworkMap should be nil on error")
|
||||
assert.Nil(t, postureChecks, "PostureChecks should be empty or nil on error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, loginErr, "Expected no error during LoginPeer with setup key: %s", tc.setupKey)
|
||||
assert.NotNil(t, loggedinPeer, "loggedinPeer should not be nil on success")
|
||||
if tc.expectExtraDNSLabelsMismatch {
|
||||
assert.NotEqual(t, tc.extraDNSLabels, loggedinPeer.ExtraDNSLabels, "ExtraDNSLabels should not match on loggedinPeer")
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, loggedinPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch on loggedinPeer")
|
||||
} else {
|
||||
assert.Equal(t, currentPeer.ExtraDNSLabels, loggedinPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch on loggedinPeer")
|
||||
}
|
||||
assert.NotNil(t, networkMap, "networkMap should not be nil on success")
|
||||
|
||||
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
|
||||
|
||||
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
|
||||
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
|
||||
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
|
||||
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -1478,8 +1730,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// create a user with auto groups
|
||||
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{
|
||||
@@ -1538,7 +1793,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1601,7 +1856,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
return update, true, nil
|
||||
}
|
||||
|
||||
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
|
||||
manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireUpdateFunc}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
@@ -1623,7 +1878,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
return update, false, nil
|
||||
}
|
||||
|
||||
manager.integratedPeerValidator = MocIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
|
||||
manager.integratedPeerValidator = MockIntegratedValidator{ValidatePeerFunc: requireNoUpdateFunc}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
@@ -1829,15 +2084,19 @@ func Test_DeletePeer(t *testing.T) {
|
||||
// account with an admin and a regular user
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Peers = map[string]*nbpeer.Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
AccountID: accountID,
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
DNSLabel: "peer1.test",
|
||||
},
|
||||
"peer2": {
|
||||
ID: "peer2",
|
||||
AccountID: accountID,
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
DNSLabel: "peer2.test",
|
||||
},
|
||||
}
|
||||
account.Groups = map[string]*types.Group{
|
||||
@@ -1867,3 +2126,264 @@ func Test_DeletePeer(t *testing.T) {
|
||||
assert.NotContains(t, group.Peers, "peer1")
|
||||
|
||||
}
|
||||
|
||||
func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
engine types.Engine
|
||||
}{
|
||||
{
|
||||
name: "PostgreSQL uniqueness error",
|
||||
engine: types.PostgresStoreEngine,
|
||||
},
|
||||
{
|
||||
name: "MySQL uniqueness error",
|
||||
engine: types.MysqlStoreEngine,
|
||||
},
|
||||
{
|
||||
name: "SQLite uniqueness error",
|
||||
engine: types.SqliteStoreEngine,
|
||||
},
|
||||
}
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
ID: "test-peer-id",
|
||||
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
DNSLabel: "test-peer-dns-label",
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine))
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
err = s.AddPeerToAccount(context.Background(), peer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = s.AddPeerToAccount(context.Background(), peer)
|
||||
result := isUniqueConstraintError(err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("error creating account: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 300
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
AccountID: accountID,
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
|
||||
}
|
||||
|
||||
<-start
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||
|
||||
seenIP := make(map[string]bool)
|
||||
for _, p := range account.Peers {
|
||||
ipStr := p.IP.String()
|
||||
if seenIP[ipStr] {
|
||||
t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr)
|
||||
}
|
||||
seenIP[ipStr] = true
|
||||
}
|
||||
|
||||
seenLabel := make(map[string]bool)
|
||||
for _, p := range account.Peers {
|
||||
if seenLabel[p.DNSLabel] {
|
||||
t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel)
|
||||
}
|
||||
seenLabel[p.DNSLabel] = true
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
|
||||
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
||||
}
|
||||
|
||||
func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||
const (
|
||||
peersCount = 1000
|
||||
updateAccountInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
||||
uapLastRun, dpLastRun atomic.Int64
|
||||
|
||||
totalNewRuns, totalOldRuns int
|
||||
)
|
||||
|
||||
uap := func(ctx context.Context, accountID string) {
|
||||
updatePeersDeleted.Store(deletedPeers.Load())
|
||||
updatePeersRuns.Add(1)
|
||||
uapLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Run("new approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := mu.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
uap(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
b.next = time.AfterFunc(updateAccountInterval, func() {
|
||||
uap(ctx, accountID)
|
||||
})
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalNewRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
|
||||
t.Run("old approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
||||
b := mu.(*sync.Mutex)
|
||||
|
||||
if !b.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(updateAccountInterval)
|
||||
b.Unlock()
|
||||
uap(ctx, accountID)
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalOldRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
|
||||
return true, nil
|
||||
}
|
||||
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -6,15 +6,15 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// GetPolicy from the store
|
||||
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
|
||||
return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
@@ -61,7 +61,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
|
||||
return saveFunc(ctx, store.LockingStrengthUpdate, policy)
|
||||
return saveFunc(ctx, policy)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
return transaction.DeletePolicy(ctx, accountID, policyID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
if policy.ID != "" {
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
policy.AccountID = accountID
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
ID: "peerB",
|
||||
IP: net.ParseIP("100.65.80.39"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
|
||||
},
|
||||
"peerC": {
|
||||
ID: "peerC",
|
||||
@@ -58,6 +59,17 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
IP: net.ParseIP("100.65.29.55"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
},
|
||||
"peerI": {
|
||||
ID: "peerI",
|
||||
IP: net.ParseIP("100.65.31.2"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
},
|
||||
"peerK": {
|
||||
ID: "peerK",
|
||||
IP: net.ParseIP("100.32.80.1"),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"GroupAll": {
|
||||
@@ -99,6 +111,20 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
"peerH",
|
||||
},
|
||||
},
|
||||
"GroupDMZ": {
|
||||
ID: "GroupDMZ",
|
||||
Name: "dmz",
|
||||
Peers: []string{
|
||||
"peerI",
|
||||
},
|
||||
},
|
||||
"GroupWorkflow": {
|
||||
ID: "GroupWorkflow",
|
||||
Name: "workflow",
|
||||
Peers: []string{
|
||||
"peerK",
|
||||
},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
@@ -148,6 +174,68 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "RuleDMZ",
|
||||
Name: "Dmz",
|
||||
Description: "No description",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleDMZ",
|
||||
Name: "Dmz",
|
||||
Description: "No description",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
PortRanges: []types.RulePortRange{
|
||||
{
|
||||
Start: 8080,
|
||||
End: 8083,
|
||||
},
|
||||
},
|
||||
Sources: []string{
|
||||
"GroupWorkstations",
|
||||
},
|
||||
Destinations: []string{
|
||||
"GroupDMZ",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "RuleWorkflow",
|
||||
Name: "Workflow",
|
||||
Description: "No description",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleWorkflow",
|
||||
Name: "Workflow",
|
||||
Description: "No description",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
PortRanges: []types.RulePortRange{
|
||||
{
|
||||
Start: 8088,
|
||||
End: 8088,
|
||||
},
|
||||
{
|
||||
Start: 9090,
|
||||
End: 9095,
|
||||
},
|
||||
},
|
||||
Sources: []string{
|
||||
"GroupWorkflow",
|
||||
},
|
||||
Destinations: []string{
|
||||
"GroupDMZ",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -158,15 +246,15 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
||||
assert.Len(t, peers, 7)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
||||
assert.Len(t, peers, 8)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
@@ -174,8 +262,9 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
assert.Contains(t, peers, account.Peers["peerF"])
|
||||
assert.Contains(t, peers, account.Peers["peerG"])
|
||||
assert.Contains(t, peers, account.Peers["peerH"])
|
||||
assert.Contains(t, peers, account.Peers["peerI"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -292,12 +381,28 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
Port: "",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.31.2",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
PortRange: types.RulePortRange{Start: 8080, End: 8083},
|
||||
PolicyID: "RuleDMZ",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.31.2",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
PortRange: types.RulePortRange{Start: 8080, End: 8083},
|
||||
PolicyID: "RuleDMZ",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
|
||||
for _, rule := range firewallRules {
|
||||
contains := false
|
||||
for _, expectedRule := range epectedFirewallRules {
|
||||
for _, expectedRule := range expectedFirewallRules {
|
||||
if rule.Equal(expectedRule) {
|
||||
contains = true
|
||||
break
|
||||
@@ -306,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
assert.True(t, contains, "rule not found in expected rules %#v", rule)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
||||
assert.Len(t, peers, 1)
|
||||
assert.Contains(t, peers, account.Peers["peerI"])
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.31.2",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "8088",
|
||||
PolicyID: "RuleWorkflow",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.31.2",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "8088",
|
||||
PolicyID: "RuleWorkflow",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
@@ -408,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -429,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -459,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
@@ -483,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
@@ -505,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
assert.Len(t, firewallRules, len(expectedFirewallRules))
|
||||
slices.SortFunc(expectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -690,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -700,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
@@ -717,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -727,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -742,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -769,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
@@ -862,7 +993,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
|
||||
func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -883,8 +1014,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -894,6 +1028,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
var policyWithGroupRulesNoPeers *types.Policy
|
||||
var policyWithDestinationPeersOnly *types.Policy
|
||||
var policyWithSourceAndDestinationPeers *types.Policy
|
||||
var err error
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
|
||||
}
|
||||
|
||||
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
peerVersion := sanitizeVersion(peer.Meta.WtVersion)
|
||||
minVersion := sanitizeVersion(n.MinVersion)
|
||||
|
||||
peerNBVersion, err := version.NewVersion(peerVersion)
|
||||
meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
constraints, err := version.NewConstraint(">= " + minVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if constraints.Check(peerNBVersion) {
|
||||
if meetsMin {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
|
||||
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
|
||||
peerVer = sanitizeVersion(peerVer)
|
||||
minVer = sanitizeVersion(minVer)
|
||||
|
||||
peerNBVer, err := version.NewVersion(peerVer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
constraints, err := version.NewConstraint(">= " + minVer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return constraints.Check(peerNBVer), nil
|
||||
}
|
||||
|
||||
@@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeetsMinVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minVer string
|
||||
peerVer string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Peer version greater than min version",
|
||||
minVer: "0.26.0",
|
||||
peerVer: "0.60.1",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Peer version equals min version",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "1.0.0",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Peer version less than min version",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "0.9.9",
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Peer version with pre-release tag greater than min version",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "1.0.1-alpha",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid peer version format",
|
||||
minVer: "1.0.0",
|
||||
peerVer: "dev",
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid min version format",
|
||||
minVer: "invalid.version",
|
||||
peerVer: "1.0.0",
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"slices"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type PeerNetworkRangeCheck struct {
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
||||
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
@@ -62,7 +62,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks)
|
||||
return transaction.SavePostureChecks(ctx, postureChecks)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
var postureChecks *posture.Checks
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -110,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID)
|
||||
return transaction.DeletePostureChecks(ctx, accountID, postureChecksID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
|
||||
|
||||
// If the posture check already has an ID, verify its existence in the store.
|
||||
if postureChecks.ID != "" {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For new posture checks, ensure no duplicates by name.
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -105,10 +105,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
Id: regularUserID,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
peer1 := &peer.Peer{
|
||||
ID: "peer1",
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
account.Peers["peer1"] = peer1
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -121,7 +125,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -137,8 +141,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -156,7 +163,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
|
||||
postureCheckA, err := manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
@@ -449,14 +456,16 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
AccountID: account.Id,
|
||||
Peers: []string{"peer1"},
|
||||
}
|
||||
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupA)
|
||||
require.NoError(t, err, "failed to create groupA")
|
||||
|
||||
groupB := &types.Group{
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Peers: []string{},
|
||||
}
|
||||
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB})
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupB)
|
||||
require.NoError(t, err, "failed to create groupB")
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
Name: "checkA",
|
||||
@@ -535,7 +544,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
groupA.Peers = []string{}
|
||||
err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA)
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user