Compare commits

...

82 Commits

Author SHA1 Message Date
bcmmbaga
1e24916dac Add missing store tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-27 15:40:46 +03:00
bcmmbaga
875b8d662c Clean up sqlite policy rules after deletion
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-27 15:40:37 +03:00
bcmmbaga
41b4e3177a Fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-27 12:53:50 +03:00
bcmmbaga
3186876d5e Add group All and default policy to client tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-27 12:11:20 +03:00
bcmmbaga
13eae9bc93 Remove unused function
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-27 11:11:18 +03:00
bcmmbaga
de99624610 Fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-26 20:15:23 +03:00
bcmmbaga
accada3311 Remove db lock on aggregate db calls
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-25 21:23:58 +03:00
bcmmbaga
71af7edd05 Refactor new account handling
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-25 17:43:39 +03:00
bcmmbaga
e17d8127e3 Remove unused store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-25 13:49:41 +03:00
bcmmbaga
ea51ce876e Remove group all checks for accounts during startup
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 18:07:23 +03:00
bcmmbaga
2115e2c3f0 Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 17:53:34 +03:00
bcmmbaga
7a6ca3ee37 Fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 17:53:28 +03:00
bcmmbaga
70b4628b5a Refactor account settings updates
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 17:53:15 +03:00
bcmmbaga
f42c775e45 Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 11:56:52 +03:00
bcmmbaga
24970a1746 Refactor get and save accounts in route ops
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 00:46:05 +03:00
bcmmbaga
de3e67e7ae Add route store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-22 00:45:45 +03:00
bcmmbaga
7be83a0199 Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-21 21:03:24 +03:00
bcmmbaga
7d0331f41e Fix prevent users from creating PATs for other users
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-21 21:03:16 +03:00
bcmmbaga
7af55fbd71 Add account locks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-21 19:50:59 +03:00
bcmmbaga
7fa1bbc722 Fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-20 22:45:20 +03:00
bcmmbaga
66d8bbf8e2 Fix database transaction locking issue
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-20 22:45:14 +03:00
bcmmbaga
6ea98f0ce7 Remove db query context and fix get user by id
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-20 22:44:06 +03:00
bcmmbaga
6a456c52bf Refactor user and PAT handling
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-19 23:42:27 +03:00
bcmmbaga
4d00207c3b Refactor account methods and mock
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-19 23:41:22 +03:00
bcmmbaga
2de0777f7a Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-19 23:33:46 +03:00
bcmmbaga
0ee56e14d9 fix lint
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-19 10:47:26 +03:00
bcmmbaga
20fc8e879e fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-19 00:54:07 +03:00
bcmmbaga
48edfa601f add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-18 16:43:19 +03:00
bcmmbaga
a2a49bdd47 fix peer fields updated after save
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-18 16:43:09 +03:00
bcmmbaga
a2fb274b86 remove duplicate store method
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-18 15:09:30 +03:00
bcmmbaga
a61e9da3e9 run peer ops in transaction
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-18 15:06:25 +03:00
bcmmbaga
f6f7260897 Fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-14 19:34:05 +03:00
bcmmbaga
c557c98390 Refactor peer to use store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-14 19:33:57 +03:00
bcmmbaga
7d849a92c0 Refactor peer handlers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-14 19:32:34 +03:00
bcmmbaga
f5e7449d01 Add lock for peer store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-14 19:24:51 +03:00
bcmmbaga
8420a52563 Refactor ephemeral peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-14 13:04:49 +03:00
bcmmbaga
6315644065 Add peer store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-14 13:04:36 +03:00
bcmmbaga
ef55b9eccc Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-13 20:41:41 +03:00
bcmmbaga
218345e0ff Refactor name server groups to use store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-13 20:41:30 +03:00
bcmmbaga
4b943c34b7 Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-13 13:16:32 +03:00
bcmmbaga
560190519d Refactor dns settings to use store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-13 13:15:47 +03:00
bcmmbaga
9bc8e6e29e Merge branch 'posturechecks-get-account-refactoring' into policy-get-account-refactoring 2024-11-12 23:53:46 +03:00
bcmmbaga
9872bee41d Refactor anyGroupHasPeers to retrieve all groups once
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 23:53:29 +03:00
bcmmbaga
3a915decd7 Add policy tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 20:15:47 +03:00
bcmmbaga
50e6389a1d Merge branch 'posturechecks-get-account-refactoring' into policy-get-account-refactoring 2024-11-12 19:06:27 +03:00
bcmmbaga
bbaee18cd5 Fix typo
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 19:05:57 +03:00
bcmmbaga
32d1b2d602 Retrieve policy groups and posture checks once for validation
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 18:53:10 +03:00
bcmmbaga
2a59f04540 Merge branch 'posturechecks-get-account-refactoring' into policy-get-account-refactoring 2024-11-12 17:16:52 +03:00
bcmmbaga
446de5e2bc Merge branch 'groups-get-account-refactoring' into posturechecks-get-account-refactoring 2024-11-12 17:15:55 +03:00
bcmmbaga
147971fdfe Merge branch 'groups-get-account-refactoring' into policy-get-account-refactoring 2024-11-12 17:15:16 +03:00
bcmmbaga
ed259a6a03 Merge branch 'main' into groups-get-account-refactoring
# Conflicts:
#	management/server/account.go
#	management/server/status/error.go
2024-11-12 17:14:45 +03:00
bcmmbaga
a3abc211b3 Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 17:11:56 +03:00
bcmmbaga
00023bf110 Merge branch 'groups-get-account-refactoring' into posturechecks-get-account-refactoring 2024-11-12 15:55:34 +03:00
bcmmbaga
2806d73161 Add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 13:38:34 +03:00
bcmmbaga
2d7f08c609 Fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 11:18:16 +03:00
bcmmbaga
0c0fd380bd Refactor policy get and save account to use store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 11:17:16 +03:00
bcmmbaga
ffce48ca5f Merge branch 'groups-get-account-refactoring' into policy-get-account-refactoring 2024-11-11 23:08:34 +03:00
bcmmbaga
d23b5c892b Retrieve modified peers once for group events
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 22:58:22 +03:00
bcmmbaga
113c21b0e1 Change setup key log level to debug for missing group
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 22:57:24 +03:00
bcmmbaga
ab00c41dad fix sonar
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 22:38:24 +03:00
bcmmbaga
664d1388aa fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 22:29:59 +03:00
bcmmbaga
010a8bfdc1 Merge branch 'main' into groups-get-account-refactoring
# Conflicts:
#	management/server/group.go
#	management/server/group/group.go
#	management/server/setupkey.go
#	management/server/sql_store.go
#	management/server/status/error.go
#	management/server/store.go
2024-11-11 21:10:02 +03:00
bcmmbaga
601d429d82 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 16:26:12 +03:00
bcmmbaga
d54b6967ce fix refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 12:38:34 +03:00
bcmmbaga
174e07fefd Refactor posture checks to remove get and save account
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 12:37:19 +03:00
bcmmbaga
871500c5cc fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-09 01:52:09 +03:00
bcmmbaga
cc04aef7b4 Merge branch 'setupkey-get-account-refactoring' into groups-get-account-refactoring 2024-11-09 01:50:10 +03:00
bcmmbaga
3ed8b9cee9 fix missing group removed from setup key activity
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-09 01:48:28 +03:00
bcmmbaga
bdeb95c58c Run groups ops in transaction
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-09 01:17:01 +03:00
bcmmbaga
6dc185e141 Preserve store engine in SqlStore transactions
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-09 01:16:03 +03:00
bcmmbaga
7100be83cd Add AddPeer and RemovePeer methods to Group struct
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-09 01:14:30 +03:00
bcmmbaga
d58cf50127 Merge branch 'setupkey-get-account-refactoring' into groups-get-account-refactoring
# Conflicts:
#	management/server/sql_store.go
2024-11-08 19:48:13 +03:00
bcmmbaga
40af1a50e3 Merge branch 'feature/get-account-refactoring' into setupkey-get-account-refactoring
# Conflicts:
#	management/server/sql_store.go
2024-11-08 19:17:28 +03:00
bcmmbaga
ac05f69131 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 18:58:19 +03:00
bcmmbaga
8126d95316 refactor GetGroupByID and add NewGroupNotFoundError
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 18:58:04 +03:00
bcmmbaga
0a70e4c5d4 Refactor groups to use store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 18:39:36 +03:00
bcmmbaga
106fc75936 refactor account peers update
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 18:38:32 +03:00
bcmmbaga
f8b5eedd38 add account lock and return auto groups map on validation
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 10:14:13 +03:00
bcmmbaga
931521d505 get only required groups for auto-group validation
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 00:59:37 +03:00
bcmmbaga
1a5f3c653c add check for regular user
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 00:37:47 +03:00
bcmmbaga
78044c226d add lock to get account groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 00:32:14 +03:00
bcmmbaga
389c9619af Refactor setup key handling to use store methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-08 00:31:41 +03:00
50 changed files with 6537 additions and 3645 deletions

View File

@@ -31,6 +31,9 @@ INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-0
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');
COMMIT; COMMIT;

File diff suppressed because it is too large Load Diff

View File

@@ -401,7 +401,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
for _, testCase := range tt { for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") store := newStore(t)
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), "account-1")
require.NoError(t, err, "failed to get account")
account.UpdateSettings(&testCase.accountSettings) account.UpdateSettings(&testCase.accountSettings)
account.Network = network account.Network = network
account.Peers = testCase.peers account.Peers = testCase.peers
@@ -419,6 +426,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
store.Close(context.Background())
} }
} }
@@ -426,27 +435,35 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io" domain := "netbird.io"
userId := "account_creator" userId := "account_creator"
accountID := "account_id" accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
store := newStore(t)
defer store.Close(context.Background())
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
} }
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestAccountManager_GetOrCreateAccountIDByUser(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userID) t.Fatalf("expected to create an account for a user %s", userID)
return return
} }
account, err = manager.Store.GetAccountByUser(context.Background(), userID) account, err := manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return return
@@ -669,15 +686,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
@@ -693,44 +707,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, accountGroups, 1, "only ALL group should exists")
}) })
t.Run("JWT groups enabled without claim name", func(t *testing.T) { t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount) _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "save account failed") require.NoError(t, err, "failed to update account settings")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
totalAccounts, err := manager.Store.GetTotalAccounts(context.Background())
require.NoError(t, err, "failed to get total accounts")
require.Equal(t, int64(1), totalAccounts, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
}) })
t.Run("JWT groups enabled", func(t *testing.T) { t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups" initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount) _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "save account failed") require.NoError(t, err, "failed to update account settings")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
totalAccounts, err := manager.Store.GetTotalAccounts(context.Background())
require.NoError(t, err, "failed to get total accounts")
require.Equal(t, int64(1), totalAccounts, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to check account existence")
require.True(t, exists, "account should exist")
require.Len(t, account.Groups, 3, "groups should be added to the account") accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
require.NoError(t, err, "failed to get account groups")
require.Len(t, accountGroups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{} groupsByNames := map[string]*group.Group{}
for _, g := range account.Groups { for _, g := range accountGroups {
groupsByNames[g.Name] = g groupsByNames[g.Name] = g
} }
@@ -746,62 +769,55 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}) })
} }
func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestAccountManager_GetAccountInfoFromPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser", userPAT := &PersonalAccessToken{
PATs: map[string]*PersonalAccessToken{ ID: "tokenId",
"tokenId": { UserID: "testuser",
ID: "tokenId", HashedToken: encodedHashedToken,
HashedToken: encodedHashedToken, CreatedAt: time.Now().UTC(),
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
} }
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
if err != nil { if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err) t.Fatalf("Error when getting Account from PAT: %s", err)
} }
assert.Equal(t, "account_id", account.Id) assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id) assert.Equal(t, "testuser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) assert.Equal(t, userPAT, pat)
} }
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser", userPAT := &PersonalAccessToken{
PATs: map[string]*PersonalAccessToken{ ID: "tokenId",
"tokenId": { UserID: "someUser",
ID: "tokenId", HashedToken: encodedHashedToken,
HashedToken: encodedHashedToken, LastUsed: time.Time{},
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -812,11 +828,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when marking PAT used: %s", err) t.Fatalf("Error when marking PAT used: %s", err)
} }
account, err = am.Store.GetAccount(context.Background(), "account_id") userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
if err != nil { require.NoError(t, err, "failed to get PAT")
t.Fatalf("Error when getting account: %s", err)
} assert.True(t, !userPAT.LastUsed.IsZero())
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
} }
func TestAccountManager_PrivateAccount(t *testing.T) { func TestAccountManager_PrivateAccount(t *testing.T) {
@@ -827,15 +842,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
} }
userId := "test_user" userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.Store.GetAccountByUser(context.Background(), userId) account, err := manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -854,32 +869,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user" userId := "test_user"
domain := "hotmail.com" domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatal(err) require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
}
if account == nil {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account != nil && account.Domain != domain { accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) require.NoError(t, err, "failed to get account domain and category")
} require.Equal(t, domain, accDomain, "expected account domain to match")
domain = "gmail.com" domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
if account == nil { accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
t.Fatalf("expected to get an account for a user %s", userId) require.NoError(t, err, "failed to get account domain and category")
} require.Equal(t, domain, accDomain, "expected account domain to match")
if account != nil && account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
} }
func TestAccountManager_GetAccountByUserID(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
@@ -911,12 +916,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
} }
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain) err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return am.Store.GetAccount(context.Background(), accountID)
} }
func TestAccountManager_GetAccount(t *testing.T) { func TestAccountManager_GetAccount(t *testing.T) {
@@ -1163,23 +1167,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatal(err)
}
serial := account.Network.CurrentSerial() // should be 0 network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account network")
if account.Network.Serial != 0 { serial := network.CurrentSerial() // should be 0
t.Errorf("expecting account network to have an initial Serial=0") require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
return
}
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { require.NoError(t, err, "failed to generate private key")
t.Fatal(err)
return
}
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedUserID := userID expectedUserID := userID
@@ -1187,16 +1186,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
if err != nil { require.NoError(t, err, "failed to add peer")
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
return
}
account, err = manager.Store.GetAccount(context.Background(), account.Id) account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil { require.NoError(t, err, "failed to get account")
t.Fatal(err)
return
}
if peer.Key != expectedPeerKey { if peer.Key != expectedPeerKey {
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
@@ -1238,8 +1231,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return return
} }
policy := Policy{ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1250,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1320,19 +1311,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -1345,7 +1323,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
} }
}() }()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@@ -1366,7 +1356,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return return
} }
policy := Policy{ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1377,9 +1367,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
if err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err) t.Errorf("save policy: %v", err)
return return
} }
@@ -1413,13 +1402,20 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
require.NoError(t, err, "failed to save group")
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
} }
policy := Policy{ policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1430,14 +1426,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
if err != nil {
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err) t.Errorf("save policy: %v", err)
return return
} }
@@ -1460,7 +1450,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return return
} }
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -1475,7 +1465,6 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
userID := "account_creator"
account, err := createAccount(manager, "test_account", userID, "netbird.cloud") account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -1504,7 +1493,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID) err = manager.DeletePeer(context.Background(), account.Id, peer.ID, userID)
if err != nil { if err != nil {
return return
} }
@@ -1526,7 +1515,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
assert.Equal(t, peer.Name, ev.Meta["name"]) assert.Equal(t, peer.Name, ev.Meta["name"])
assert.Equal(t, peer.FQDN(account.Domain), ev.Meta["fqdn"]) assert.Equal(t, peer.FQDN(account.Domain), ev.Meta["fqdn"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, peer.IP.String(), ev.TargetID) assert.Equal(t, peer.ID, ev.TargetID)
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
} }
@@ -1856,16 +1845,15 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpiration: time.Hour, require.NoError(t, err, "unable to get account settings")
PeerLoginExpirationEnabled: true,
}) settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@@ -1882,11 +1870,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first // disable expiration first
update := peer.Copy() update := peer.Copy()
update.LoginExpirationEnabled = false update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine // enabling expiration should trigger the routine
update.LoginExpirationEnabled = true update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1910,10 +1898,14 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpirationEnabled: true, require.NoError(t, err, "unable to get account settings")
})
settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@@ -1930,11 +1922,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger // when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1965,7 +1954,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID) account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@@ -1978,11 +1967,15 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Done() wg.Done()
}, },
} }
// enabling PeerLoginExpirationEnabled should trigger the expiration job // enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpiration: time.Hour, require.NoError(t, err, "unable to get account settings")
PeerLoginExpirationEnabled: true,
}) settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1992,10 +1985,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1) wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel // disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Hour, _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, settings)
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second) failed = waitTimeout(wg, time.Second)
if failed { if failed {
@@ -2010,30 +2001,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpiration: time.Hour, require.NoError(t, err, "unable to get account settings")
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings.PeerLoginExpirationEnabled = false
settings.PeerLoginExpiration = time.Hour
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour) assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpiration = time.Second
PeerLoginExpiration: time.Second, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpiration = time.Hour * 24 * 181
PeerLoginExpiration: time.Hour * 24 * 181, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }
@@ -2714,7 +2704,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2734,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"strconv" "strconv"
"sync" "sync"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil { if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) if err != nil {
if err != nil {
return err
}
}
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
for _, id := range addedGroups { if user.AccountID != accountID {
group := account.GetGroup(id) return status.NewUserNotPartOfAccountError()
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
} }
for _, id := range removedGroups { if !user.HasAdminPower() {
group := account.GetGroup(id) return status.NewAdminPermissionError()
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { var updateAccountPeers bool
am.updateAccountPeers(ctx, account) var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
}
for _, groupID := range addedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
})
}
for _, groupID := range removedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
})
}
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}
return validateGroups(settings.DisabledManagementGroups, groups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ protoUpdate := &proto.DNSConfig{

View File

@@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Fatal("failed to init testing account") t.Fatal("failed to init testing account")
} }
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil { if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("DNS settings for new accounts shouldn't return nil") t.Fatal("DNS settings for new accounts shouldn't return nil")
} }
account.DNSSettings = DNSSettings{ err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{
DisabledManagementGroups: []string{group1ID}, DisabledManagementGroups: []string{group1ID},
} })
require.NoError(t, err, "failed to update DNS settings")
err = am.Store.SaveAccount(context.Background(), account) dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil { if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
} }
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) _, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID)
if err == nil { if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user") t.Errorf("An error should be returned when getting the DNS settings with a regular user")
} }
@@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings)
if err != nil { if err != nil {
if testCase.shouldFail { if testCase.shouldFail {
return return
@@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err) t.Error(err)
} }
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) updatedAccount, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil { if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err) t.Errorf("should be able to retrieve updated account, got err: %s", err)
} }
@@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
peer1, err := account.FindPeerByPubKey(dnsPeer1Key) peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
peer2, err := account.FindPeerByPubKey(dnsPeer2Key) peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy() accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account DNS settings")
dnsSettings := accountDNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to update DNS settings")
require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err) require.NoError(t, err)
@@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
Key: dnsPeer1Key, Key: dnsPeer1Key,
@@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com" domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain)
if err != nil {
account.Users[dnsRegularUserID] = &User{ return "", err
Id: dnsRegularUserID,
Role: UserRoleUser,
} }
err := am.Store.SaveAccount(context.Background(), account) err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: dnsRegularUserID,
AccountID: dnsAccountID,
Role: UserRoleUser,
})
if err != nil { if err != nil {
return nil, err return "", err
} }
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil { if err != nil {
return nil, err return "", err
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer1, err = account.FindPeerByPubKey(peer1.Key) _, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, err = account.FindPeerByPubKey(peer2.Key) err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{
{
ID: dnsGroup1ID,
AccountID: dnsAccountID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
},
{
ID: dnsGroup2ID,
AccountID: dnsAccountID,
Name: dnsGroup2ID,
},
})
if err != nil { if err != nil {
return nil, err return "", err
} }
newGroup1 := &group.Group{ allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All")
ID: dnsGroup1ID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
}
newGroup2 := &group.Group{
ID: dnsGroup2ID,
Name: dnsGroup2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
return nil, err return "", err
} }
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{
ID: dnsNSGroup1, ID: dnsNSGroup1,
Name: "ns-group-1", AccountID: dnsAccountID,
Name: "ns-group-1",
NameServers: []dns.NameServer{{ NameServers: []dns.NameServer{{
IP: netip.MustParseAddr(savedPeer1.IP.String()), IP: netip.MustParseAddr(savedPeer1.IP.String()),
NSType: dns.UDPNameServerType, NSType: dns.UDPNameServerType,
@@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Primary: true, Primary: true,
Enabled: true, Enabled: true,
Groups: []string{allGroup.ID}, Groups: []string{allGroup.ID},
} })
err = am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
return am.Store.GetAccount(context.Background(), account.Id) return dnsAccountID, nil
} }
func generateTestData(size int) nbdns.Config { func generateTestData(size int) nbdns.Config {

View File

@@ -20,10 +20,10 @@ var (
) )
type ephemeralPeer struct { type ephemeralPeer struct {
id string id string
account *Account accountID string
deadline time.Time deadline time.Time
next *ephemeralPeer next *ephemeralPeer
} }
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
@@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
e.peersLock.Lock() e.peersLock.Lock()
defer e.peersLock.Unlock() defer e.peersLock.Unlock()
@@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
return return
} }
e.addPeer(peer.ID, a, newDeadLine()) e.addPeer(peer.AccountID, peer.ID, newDeadLine())
if e.timer == nil { if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx) e.cleanup(ctx)
@@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background()) peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := newDeadLine() t := newDeadLine()
count := 0 count := 0
for _, a := range accounts { for _, p := range peers {
for id, p := range a.Peers { if p.Ephemeral {
if p.Ephemeral { count++
count++ e.addPeer(p.AccountID, p.ID, t)
e.addPeer(id, a, t)
}
} }
} }
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
} }
@@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
for id, p := range deletePeers { for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} }
} }
} }
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{ ep := &ephemeralPeer{
id: id, id: peerID,
account: account, accountID: accountID,
deadline: deadline, deadline: deadline,
} }
if e.headPeer == nil { if e.headPeer == nil {

View File

@@ -7,25 +7,12 @@ import (
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/require"
) )
type MockStore struct { type MockStore struct {
Store Store
account *Account accountID string
}
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil
}
return nil, status.NewPeerNotFoundError(peerId)
} }
type MocAccountManager struct { type MocAccountManager struct {
@@ -33,9 +20,8 @@ type MocAccountManager struct {
store *MockStore store *MockStore
} }
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error {
delete(a.store.account.Peers, peerID) return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
return nil //nolint:nil
} }
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {
@@ -44,23 +30,26 @@ func TestNewManager(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers { peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) require.NoError(t, err, "failed to get account peers")
} require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers")
} }
func TestNewManagerPeerConnected(t *testing.T) { func TestNewManagerPeerConnected(t *testing.T) {
@@ -69,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerConnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
expected := numberOfPeers + 1 peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
if len(store.account.Peers) != expected { require.NoError(t, err, "failed to get account peers")
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers")
}
} }
func TestNewManagerPeerDisconnected(t *testing.T) { func TestNewManagerPeerDisconnected(t *testing.T) {
@@ -97,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
for _, v := range peers {
mgr.OnPeerConnected(context.Background(), v)
} }
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerDisconnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
expected := numberOfPeers + numberOfEphemeralPeers - 1 expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected { require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers")
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
} }
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error {
store.account = newAccountWithId(context.Background(), "my account", "", "") accountID := "my account"
err := newAccountWithId(context.Background(), store, accountID, "", "")
if err != nil {
return err
}
store.accountID = accountID
for i := 0; i < numberOfPeers; i++ { for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i) peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
AccountID: accountID,
Ephemeral: false, Ephemeral: false,
} }
store.account.Peers[p.ID] = p err = store.AddPeerToAccount(context.Background(), p)
if err != nil {
return err
}
} }
for i := 0; i < numberOfEphemeralPeers; i++ { for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i) peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
AccountID: accountID,
Ephemeral: true, Ephemeral: true,
} }
store.account.Peers[p.ID] = p err = store.AddPeerToAccount(context.Background(), p)
if err != nil {
return err
}
} }
return nil
} }

View File

@@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return status.NewAdminPermissionError()
} }
return nil return nil
@@ -49,8 +53,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -58,13 +61,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@@ -77,79 +79,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account. // SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock. // Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func() var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
for _, newGroup := range newGroups { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { groupIDs := make([]string, 0, len(groups))
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) for _, newGroup := range groups {
} if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
} }
// Avoid duplicate groups only for the API issued groups. newGroup.AccountID = accountID
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. groupsToSave = append(groupsToSave, newGroup)
if existingGroup != nil { groupIDs = append(groupIDs, newGroup.ID)
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String() events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
} }
for _, peerID := range newGroup.Peers { updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if account.Peers[peerID] == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return err
}
} }
oldGroup := account.Groups[newGroup.ID] if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
account.Groups[newGroup.ID] = newGroup return err
}
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
eventsToStore = append(eventsToStore, events...) })
} if err != nil {
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@@ -159,35 +156,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { modifiedPeers := slices.Concat(addedPeers, removedPeers)
peer := account.Peers[p] peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, ok := peers[peerID]
if peer == nil { if !ok {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, meta := map[string]any{
map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID,
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), }
}) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
}) })
} }
@@ -210,40 +214,47 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if user.AccountID != accountID {
if !ok { return status.NewUserNotPartOfAccountError()
return nil
} }
allGroup, err := account.GetGroupAll() if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var group *nbgroup.Group
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID)
})
if err != nil { if err != nil {
return err return err
} }
if allGroup.ID == groupID { am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil return nil
} }
@@ -254,93 +265,90 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
// //
// If an error occurs while deleting a group, the function skips it and continues deleting other groups. // If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end. // Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
var allErrors error if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
} }
account.Network.IncSerial() if user.IsRegularUser() {
if err = am.Store.SaveAccount(ctx, account); err != nil { return status.NewAdminPermissionError()
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
})
if err != nil {
return err return err
} }
for _, g := range deletedGroups { for _, group := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) var group *nbgroup.Group
defer unlock() var updateAccountPeers bool
var err error
account, err := am.Store.GetAccount(ctx, accountID) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updateAccountPeers {
if !ok { am.updateAccountPeers(ctx, accountID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
@@ -348,41 +356,80 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) var group *nbgroup.Group
defer unlock() var updateAccountPeers bool
var err error
account, err := am.Store.GetAccount(ctx, accountID) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if updateAccountPeers {
if !ok { am.updateAccountPeers(ctx, accountID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { // validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
@@ -390,51 +437,77 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return &GroupLinkError{"integrated validator", group.Name} return err
} }
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
} }
return nil return nil
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r return true, r
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies { for _, policy := range policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +519,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
} }
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups { for _, dns := range nameServerGroups {
for _, g := range dns.Groups { for _, g := range dns.Groups {
if g == groupID { if g == groupID {
@@ -454,11 +533,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
} }
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys { for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) { if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey return true, setupKey
@@ -468,7 +554,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
} }
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return true, user return true, user
@@ -477,31 +569,47 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil return false, nil
} }
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers. // anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool { func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
for _, groupID := range groupIDs { groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
if group, exists := account.Groups[groupID]; exists && group.HasPeers() { if err != nil {
return true return false, err
}
} }
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { for _, group := range groups {
for _, groupID := range groupIDs { if group.HasPeers() {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { return true, nil
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
} }
} }
return false return false, nil
} }

View File

@@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool {
func (g *Group) IsGroupAll() bool { func (g *Group) IsGroupAll() bool {
return g.Name == "All" return g.Name == "All"
} }
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@@ -0,0 +1,90 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
} }
routeResource := &route.Route{ routeResource := &route.Route{
ID: "example route", ID: "example route",
Groups: []string{groupForRoute.ID}, AccountID: accountID,
Groups: []string{groupForRoute.ID},
} }
routePeerGroupResource := &route.Route{ routePeerGroupResource := &route.Route{
ID: "example route with peer groups", ID: "example route with peer groups",
AccountID: accountID,
PeerGroups: []string{groupForRoute2.ID}, PeerGroups: []string{groupForRoute2.ID},
} }
nameServerGroup := &nbdns.NameServerGroup{ nameServerGroup := &nbdns.NameServerGroup{
ID: "example name server group", ID: "example name server group",
Groups: []string{groupForNameServerGroups.ID}, AccountID: accountID,
Groups: []string{groupForNameServerGroups.ID},
} }
policy := &Policy{ policy := &Policy{
ID: "example policy", ID: "example policy",
AccountID: accountID,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "example policy rule", ID: "example policy rule",
PolicyID: "example policy",
Destinations: []string{groupForPolicies.ID}, Destinations: []string{groupForPolicies.ID},
}, },
}, },
@@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
setupKey := &SetupKey{ setupKey := &SetupKey{
Id: "example setup key", Id: "example setup key",
AccountID: accountID,
AutoGroups: []string{groupForSetupKeys.ID}, AutoGroups: []string{groupForSetupKeys.ID},
} }
user := &User{ user := &User{
Id: "example user", Id: "example user",
AccountID: accountID,
AutoGroups: []string{groupForUsers.ID}, AutoGroups: []string{groupForUsers.ID},
} }
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) if err != nil {
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) return nil, nil, err
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
acc, err := am.Store.GetAccount(context.Background(), account.Id) err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup)
if err != nil {
return nil, nil, err
}
err = am.Store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user)
if err != nil {
return nil, nil, err
}
err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{
groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies,
groupForSetupKeys, groupForUsers, groupForIntegration,
})
if err != nil {
return nil, nil, err
}
acc, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -500,8 +530,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}) })
// adding a group to policy // adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -512,7 +541,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
}, false) })
assert.NoError(t, err) assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update // Saving a group linked to policy should update account peers and send peer update

View File

@@ -100,13 +100,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) resp := toAccountResponse(accountID, updatedSettings)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }

View File

@@ -29,7 +29,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil return account.Settings, nil
}, },
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -39,9 +39,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
accCopy := account.Copy() return newSettings.Copy(), nil
accCopy.UpdateSettings(newSettings)
return accCopy, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
) )
authMiddleware := middleware.NewAuthMiddleware( authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT, accountManager.GetAccountInfoFromPAT,
jwtValidator.ValidateAndParse, jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed, accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups, accountManager.CheckUserAccessByJWTGroups,

View File

@@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// GetAccountFromPATFunc function // GetAccountInfoFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function // ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@@ -47,7 +47,7 @@ const (
) )
// NewAuthMiddleware instance constructor // NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware { audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" { if userIdClaim == "" {
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
} }
return &AuthMiddleware{ return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT, getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken, validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed, markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth) token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
account, user, pat, err := m.getAccountFromPAT(r.Context(), token) user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil { if err != nil {
return fmt.Errorf("invalid Token: %w", err) return fmt.Errorf("invalid Token: %w", err)
} }
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{} claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@@ -33,7 +33,8 @@ var testAccount = &server.Account{
Domain: domain, Domain: domain,
Users: map[string]*server.User{ Users: map[string]*server.User{
userID: { userID: {
Id: userID, Id: userID,
AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{ PATs: map[string]*server.PersonalAccessToken{
tokenID: { tokenID: {
ID: tokenID, ID: tokenID,
@@ -49,11 +50,11 @@ var testAccount = &server.Account{
}, },
} }
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
if token == PAT { if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
} }
return nil, nil, nil, fmt.Errorf("PAT invalid") return nil, nil, "", "", fmt.Errorf("PAT invalid")
} }
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
@@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
) )
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT, mockGetAccountInfoFromPAT,
mockValidateAndParseToken, mockValidateAndParseToken,
mockMarkPATUsed, mockMarkPATUsed,
mockCheckUserAccessByJWTGroups, mockCheckUserAccessByJWTGroups,

View File

@@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil return peerToReturn, nil
} }
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
@@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
groupsInfo := toGroupsInfo(account.Groups, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, err, w)
return
}
groupsInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
@@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
} }
} }
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(account) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w) h.deletePeer(r.Context(), accountID, userID, peerID, w)
return return
case http.MethodGet, http.MethodPut: case http.MethodGet:
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) h.getPeer(r.Context(), accountID, peerID, userID, w)
if err != nil { return
util.WriteError(r.Context(), err, w) case http.MethodPut:
return h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
return return
default: default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
return return
} }
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers)) respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range account.Peers { for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(account) validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
} }
} }
dnsDomain := h.accountManager.GetDNSDomain() validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
@@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
} }
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum groupsInfo := make([]api.GroupMinimum, 0, len(groups))
groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
_, ok := groupsChecked[group.ID] groupsInfo = append(groupsInfo, api.GroupMinimum{
if ok { Id: group.ID,
continue Name: group.Name,
} PeersCount: len(group.Peers),
groupsChecked[group.ID] = struct{}{} })
for _, pk := range group.Peers {
if pk == peerID {
info := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
groupsInfo = append(groupsInfo, info)
break
}
}
} }
return groupsInfo return groupsInfo
} }

View File

@@ -39,6 +39,68 @@ const (
) )
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: "test_id",
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: "test_id",
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: "test_id",
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return &PeersHandler{ return &PeersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -67,74 +129,31 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil return peers, nil
}, },
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
peersID := make([]string, len(peers))
for _, peer := range peers {
peersID = append(peersID, peer.ID)
}
return []*nbgroup.Group{
{
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: peersID,
},
}, nil
},
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func() string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
}, },
GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) {
return account, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return account, nil return account, nil
}, },
HasConnectedChannelFunc: func(peerID string) bool { HasConnectedChannelFunc: func(peerID string) bool {

View File

@@ -6,10 +6,8 @@ import (
"strconv" "strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -122,14 +120,9 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return return
} }
isUpdate := policyID != "" policy := &server.Policy{
if policyID == "" {
policyID = xid.New().String()
}
policy := server.Policy{
ID: policyID, ID: policyID,
AccountID: accountID,
Name: req.Name, Name: req.Name,
Enabled: req.Enabled, Enabled: req.Enabled,
Description: req.Description, Description: req.Description,
@@ -137,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
for _, rule := range req.Rules { for _, rule := range req.Rules {
pr := server.PolicyRule{ pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor ID: policyID, // TODO: when policy can contain multiple rules, need refactor
PolicyID: policyID,
Name: rule.Name, Name: rule.Name,
Destinations: rule.Destinations, Destinations: rule.Destinations,
Sources: rule.Sources, Sources: rule.Sources,
@@ -225,7 +219,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks policy.SourcePostureChecks = *req.SourcePostureChecks
} }
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -236,7 +231,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return return
} }
resp := toPolicyResponse(allGroups, &policy) resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return

View File

@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
} }
return policy, nil return policy, nil
}, },
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") { if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set" policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set" policy.Rules[0].ID = "id-was-set"
} }
return nil return policy, nil
}, },
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil

View File

@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return return
} }
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
} }
return p, nil return p, nil
}, },
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
postureChecks.ID = "postureCheck" postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil { if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
} }
return nil return postureChecks, nil
}, },
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID] _, ok := testPostureChecks[postureChecksID]

View File

@@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -52,30 +54,60 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a) return am.Store.SaveAccount(ctx, a)
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groups) == 0 { if len(groupIDs) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil { err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
return false, err for _, groupID := range groupIDs {
} _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
for _, group := range groups { if err != nil {
var found bool return err
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
} }
} }
if !found { return nil
return false, nil })
} if err != nil {
return false, err
} }
return true, nil return true, nil
} }
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) var err error
var groups []*nbgroup.Group
var peers []*nbpeer.Peer
var settings *Settings
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
peers, err = transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
return err
})
if err != nil {
return nil, err
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
} }

View File

@@ -22,9 +22,9 @@ import (
) )
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
@@ -45,16 +45,16 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -89,15 +89,15 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error) GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager GetIdpManagerFunc func() idp.Manager
@@ -131,7 +131,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me") panic("implement me")
} }
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeers := make(map[string]struct{}) approvedPeers := make(map[string]struct{})
for id := range account.Peers { for id := range account.Peers {
approvedPeers[id] = struct{}{} approvedPeers[id] = struct{}{}
@@ -171,16 +176,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
} }
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser( func (am *MockAccountManager) GetOrCreateAccountIDByUser(
ctx context.Context, userId, domain string, ctx context.Context, userId, domain string,
) (*server.Account, error) { ) (string, error) {
if am.GetOrCreateAccountByUserFunc != nil { if am.GetOrCreateAccountIDByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain)
} }
return nil, status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetOrCreateAccountByUser is not implemented", "method GetOrCreateAccountIDByUser is not implemented",
) )
} }
@@ -222,19 +227,19 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
} }
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
if am.MarkPeerConnectedFunc != nil { if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
} }
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface // GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, pat string) (*server.User, *server.PersonalAccessToken, string, string, error) {
if am.GetAccountFromPATFunc != nil { if am.GetAccountInfoFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat) return am.GetAccountInfoFromPATFunc(ctx, pat)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
} }
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
@@ -354,14 +359,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {
@@ -395,11 +392,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
} }
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface // SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
if am.SavePolicyFunc != nil { if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) return am.SavePolicyFunc(ctx, accountID, userID, policy)
} }
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
} }
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
@@ -675,7 +672,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
} }
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
if am.UpdateAccountSettingsFunc != nil { if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
} }
@@ -691,9 +688,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
} }
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, account) return am.SyncPeerFunc(ctx, sync, accountID)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
} }
@@ -739,11 +736,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
} }
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface // SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil { if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
} }
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
} }
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
@@ -840,3 +837,11 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
} }
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
if am.GetPeerGroupsFunc != nil {
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}

View File

@@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
newNSGroup := &nbdns.NameServerGroup{ newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID,
Name: name, Name: name,
Description: description, Description: description,
NameServers: nameServerList, NameServers: nameServerList,
@@ -54,27 +62,34 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled, SearchDomainsEnabled: searchDomainEnabled,
} }
err = validateNameServerGroup(false, newNSGroup, account) var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
@@ -87,59 +102,96 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
} }
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
err = validateNameServerGroup(true, nsGroupToSave, account) if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
})
if err != nil { if err != nil {
return err return err
} }
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
nsGroup := account.NameServerGroups[nsGroupID] if user.AccountID != accountID {
if nsGroup == nil { return status.NewUserNotPartOfAccountError()
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
} }
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial() var nsGroup *nbdns.NameServerGroup
if err = am.Store.SaveAccount(ctx, account); err != nil { var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
if err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
})
if err != nil {
return err return err
} }
if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
@@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
} }
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil { if err != nil {
return err return err
} }
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers) err = validateNSList(nameserverGroup.NameServers)
if err != nil { if err != nil {
return err return err
} }
err = validateGroups(nameserverGroup.Groups, account.Groups) nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
return nil err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
if err != nil {
return err
}
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
} }
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
@@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil return nil
} }
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
} }
for _, nsGroup := range nsGroupMap { for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID { if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
} }
} }
@@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
} }
func validateNSList(list []nbdns.NameServer) error { func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list) nsListLength := len(list)
if nsListLenght == 0 || nsListLenght > 3 { if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
} }
return nil return nil
@@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" { if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string") return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
} }
found := false if _, found := groups[id]; !found {
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id) return status.Errorf(status.InvalidArgument, "group id %s not found", id)
} }
} }
@@ -277,11 +340,3 @@ func validateDomain(domain string) error {
return nil return nil
} }
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@@ -6,6 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
outNSGroup, err := am.CreateNameServerGroup( outNSGroup, err := am.CreateNameServerGroup(
context.Background(), context.Background(),
account.Id, accountID,
testCase.inputArgs.name, testCase.inputArgs.name,
testCase.inputArgs.description, testCase.inputArgs.description,
testCase.inputArgs.nameServers, testCase.inputArgs.nameServers,
@@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup testCase.existingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save existing nameserver group")
if err != nil {
t.Error("account should be saved")
}
var nsGroupToSave *nbdns.NameServerGroup var nsGroupToSave *nbdns.NameServerGroup
if !testCase.skipCopying { if !testCase.skipCopying {
nsGroupToSave = testCase.existingNSGroup.Copy() nsGroupToSave = testCase.existingNSGroup.Copy()
@@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) {
} }
} }
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave)
testCase.errFunc(t, err) testCase.errFunc(t, err)
if !testCase.shouldCreate { if !testCase.shouldCreate {
return return
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID)
if err != nil { require.NoError(t, err, "failed to get saved nameserver group")
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved)
testCase.expectedNSGroup.AccountID = accountID
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) { if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup) t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
} }
@@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup testingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup)
require.NoError(t, err, "failed to save nameserver group")
err = am.Store.SaveAccount(context.Background(), account) err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
if err != nil { if err != nil {
t.Error("deleting nameserver group failed with error: ", err) t.Error("deleting nameserver group failed with error: ", err)
} }
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) _, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID)
if err != nil { require.NotNil(t, err)
t.Error("failed to retrieve saved account with error: ", err) sErr, ok := status.FromError(err)
} require.True(t, ok, "error should be a status error")
assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete")
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
if found {
t.Error("nameserver group shouldn't be found after delete")
}
} }
func TestGetNameServerGroup(t *testing.T) { func TestGetNameServerGroup(t *testing.T) {
@@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID)
if err != nil { if err != nil {
t.Error("getting existing nameserver group failed with error: ", err) t.Error("getting existing nameserver group failed with error: ", err)
} }
@@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("got a nil group while getting nameserver group with ID") t.Error("got a nil group while getting nameserver group with ID")
} }
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") _, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing")
if err == nil { if err == nil {
t.Error("getting not existing nameserver group should return error, got nil") t.Error("getting not existing nameserver group should return error, got nil")
} }
@@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
accountID := "testingAcc"
userID := testUserID
domain := "example.com"
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
Key: nsGroupPeer1Key, Key: nsGroupPeer1Key,
Name: "test-host1@netbird.io", Name: "test-host1@netbird.io",
@@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
} }
existingNSGroup := nbdns.NameServerGroup{ existingNSGroup := nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
AccountID: accountID,
Name: existingNSGroupName, Name: existingNSGroupName,
Description: "", Description: "",
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
@@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
Enabled: true, Enabled: true,
} }
accountID := "testingAcc" err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
userID := testUserID
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
newGroup1 := &nbgroup.Group{
ID: group1ID,
Name: group1ID,
}
newGroup2 := &nbgroup.Group{
ID: group2ID,
Name: group2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup)
if err != nil {
return "", err
}
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{
{
ID: group1ID,
AccountID: accountID,
Name: group1ID,
},
{
ID: group2ID,
AccountID: accountID,
Name: group2ID,
},
})
if err != nil {
return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
if err != nil { if err != nil {
return nil, err return "", err
} }
return account, nil return accountID, nil
} }
func TestValidateDomain(t *testing.T) { func TestValidateDomain(t *testing.T) {

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ type Peer struct {
// CreatedAt records the time the peer was created // CreatedAt records the time the peer was created
CreatedAt time.Time CreatedAt time.Time
// Indicate ephemeral peer attribute // Indicate ephemeral peer attribute
Ephemeral bool Ephemeral bool `gorm:"index"`
// Geo location based on connection IP // Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"` Location Location `gorm:"embedded;embeddedPrefix:location_"`
} }

View File

@@ -13,6 +13,7 @@ import (
"testing" "testing"
"time" "time"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -22,7 +23,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
var ( var (
group1 nbgroup.Group group1 nbgroup.Group
group2 nbgroup.Group group2 nbgroup.Group
policy Policy
) )
group1.ID = xid.New().String() group1.ID = xid.New().String()
group2.ID = xid.New().String() group2.ID = xid.New().String()
group1.Name = "src" group1.Name = "src"
group2.Name = "dst" group2.Name = "dst"
policy.ID = xid.New().String()
group1.Peers = append(group1.Peers, peer1.ID) group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID) group2.Peers = append(group2.Peers, peer2.ID)
@@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return return
} }
policy.Name = "test" policy := &Policy{
policy.Enabled = true Name: "test",
policy.Rules = []*PolicyRule{ Enabled: true,
{ Rules: []*PolicyRule{
Enabled: true, {
Sources: []string{group1.ID}, Enabled: true,
Destinations: []string{group2.ID}, Sources: []string{group1.ID},
Bidirectional: true, Destinations: []string{group2.ID},
Action: PolicyTrafficActionAccept, Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return
@@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
} }
policy.Enabled = false policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return
@@ -468,21 +468,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := "account_creator" adminUser := "account_creator"
someUser := "some_user" someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "") err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
account.Users[someUser] = &User{ require.NoError(t, err, "failed to create account")
Id: someUser,
Role: UserRoleUser,
}
account.Settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccount(context.Background(), account) err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
if err != nil { Id: someUser,
t.Fatal(err) AccountID: accountID,
return Role: UserRoleUser,
} })
require.NoError(t, err, "failed to create user")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
// two peers one added by a regular user and one with a setup key // two peers one added by a regular user and one with a setup key
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -536,7 +540,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer) assert.NotNil(t, peer)
// delete the all-to-all policy so that user's peer1 has no access to peer2 // delete the all-to-all policy so that user's peer1 has no access to peer2
for _, policy := range account.Policies { accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account policies")
for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -655,21 +662,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := "account_creator" adminUser := "account_creator"
someUser := "some_user" someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{ err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
require.NoError(t, err, "failed to create account")
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: someUser, Id: someUser,
AccountID: accountID,
Role: testCase.role, Role: testCase.role,
IsServiceUser: testCase.isServiceUser, IsServiceUser: testCase.isServiceUser,
} })
account.Policies = []*Policy{} require.NoError(t, err, "failed to create user")
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccount(context.Background(), account) accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
if err != nil { require.NoError(t, err, "failed to get account policies")
t.Fatal(err)
return for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
require.NoError(t, err, "failed to delete policy")
} }
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
peerKey1, err := wgtypes.GeneratePrivateKey() peerKey1, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -725,10 +744,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
adminUser := "account_creator" adminUser := "account_creator"
regularUser := "regular_user" regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "") err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
account.Users[regularUser] = &User{ if err != nil {
Id: regularUser, return nil, "", "", err
Role: UserRoleUser, }
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: regularUser,
AccountID: accountID,
Role: UserRoleUser,
})
if err != nil {
return nil, "", "", err
} }
// Create peers // Create peers
@@ -742,31 +769,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
UserID: regularUser, UserID: regularUser,
} }
account.Peers[peer.ID] = peer err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, "", "", err
}
} }
// Create groups and policies // Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ { for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i) groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{ group := &nbgroup.Group{
ID: groupID, ID: groupID,
Name: fmt.Sprintf("Group %d", i), AccountID: accountID,
Name: fmt.Sprintf("Group %d", i),
} }
for j := 0; j < peers/groups; j++ { for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
} }
account.Groups[groupID] = group
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
return nil, "", "", err
}
// Create a policy for this group // Create a policy for this group
policy := &Policy{ policy := &Policy{
ID: fmt.Sprintf("policy-%d", i), ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i), AccountID: accountID,
Enabled: true, Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: fmt.Sprintf("rule-%d", i), ID: fmt.Sprintf("rule-%d", i),
PolicyID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i), Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true, Enabled: true,
Sources: []string{groupID}, Sources: []string{groupID},
@@ -777,22 +813,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
}, },
}, },
} }
account.Policies = append(account.Policies, policy)
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, "", "", err
}
} }
account.PostureChecks = []*posture.Checks{ err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{
{ ID: "PostureChecksAll",
ID: "PostureChecksAll", AccountID: accountID,
Name: "All", Name: "All",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1", MinVersion: "0.0.1",
},
}, },
}, },
} })
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, "", "", err return nil, "", "", err
} }
@@ -877,7 +914,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now() start := time.Now()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account) manager.updateAccountPeers(ctx, account.Id)
} }
duration := time.Since(start) duration := time.Since(start)
@@ -1445,9 +1482,9 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update // Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) { t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy", AccountID: account.Id,
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
Enabled: true, Enabled: true,
@@ -1457,7 +1494,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
}, false) })
require.NoError(t, err) require.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})

View File

@@ -41,6 +41,7 @@ type PersonalAccessToken struct {
func (t *PersonalAccessToken) Copy() *PersonalAccessToken { func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
return &PersonalAccessToken{ return &PersonalAccessToken{
ID: t.ID, ID: t.ID,
UserID: t.UserID,
Name: t.Name, Name: t.Name,
HashedToken: t.HashedToken, HashedToken: t.HashedToken,
ExpirationDate: t.ExpirationDate, ExpirationDate: t.ExpirationDate,
@@ -58,7 +59,7 @@ type PersonalAccessTokenGenerated struct {
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken() hashedToken, plainToken, err := generateNewToken()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -67,6 +68,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
return &PersonalAccessTokenGenerated{ return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{ PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(), ID: xid.New().String(),
UserID: targetID,
Name: name, Name: name,
HashedToken: hashedToken, HashedToken: hashedToken,
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),

View File

@@ -3,13 +3,13 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"slices"
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/netbird/management/proto"
"github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -125,6 +125,7 @@ type PolicyRule struct {
func (pm *PolicyRule) Copy() *PolicyRule { func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{ rule := &PolicyRule{
ID: pm.ID, ID: pm.ID,
PolicyID: pm.PolicyID,
Name: pm.Name, Name: pm.Name,
Description: pm.Description, Description: pm.Description,
Enabled: pm.Enabled, Enabled: pm.Enabled,
@@ -171,6 +172,7 @@ type Policy struct {
func (p *Policy) Copy() *Policy { func (p *Policy) Copy() *Policy {
c := &Policy{ c := &Policy{
ID: p.ID, ID: p.ID,
AccountID: p.AccountID,
Name: p.Name, Name: p.Name,
Description: p.Description, Description: p.Description,
Enabled: p.Enabled, Enabled: p.Enabled,
@@ -343,44 +345,72 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return nil, err
} }
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var action = activity.PolicyAdded
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
saveFunc := transaction.CreatePolicy
if isUpdate {
action = activity.PolicyUpdated
saveFunc = transaction.SavePolicy
}
return saveFunc(ctx, LockingStrengthUpdate, policy)
})
if err != nil { if err != nil {
return err return nil, err
} }
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
action := activity.PolicyAdded
if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return policy, nil
} }
// DeletePolicy from the store // DeletePolicy from the store
@@ -388,112 +418,136 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
policy, err := am.deletePolicy(account, policyID) if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var policy *Policy
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
if err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
})
if err != nil { if err != nil {
return err return err
} }
account.Network.IncSerial() am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account)
} }
return nil return nil
} }
// ListPolicies from the store // ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
policyIdx := -1 func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate { if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
if policyIdx < 0 { if err != nil {
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) return false, err
} }
oldPolicy := account.Policies[policyIdx] if !policy.Enabled && !existingPolicy.Enabled {
// Update the existing policy
account.Policies[policyIdx] = policyToSave
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil return false, nil
} }
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
} }
// Add the new policy to the account return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
account.Policies = append(account.Policies, policyToSave)
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
} }
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { // validatePolicy validates the policy and its rules.
result := make([]*proto.FirewallRule, len(rules)) func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
for i := range rules { if policy.ID != "" {
rule := rules[i] _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
if err != nil {
result[i] = &proto.FirewallRule{ return err
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
} }
} else {
policy.ID = xid.New().String()
policy.AccountID = accountID
} }
return result
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
if err != nil {
return err
}
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
if err != nil {
return err
}
for i, rule := range policy.Rules {
ruleCopy := rule.Copy()
if ruleCopy.ID == "" {
ruleCopy.ID = xid.New().String()
ruleCopy.PolicyID = policy.ID
}
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
policy.Rules[i] = ruleCopy
}
if policy.SourcePostureChecks != nil {
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
}
return nil
} }
// getAllPeersFromGroups for given peer ID and list of groups // getAllPeersFromGroups for given peer ID and list of groups
@@ -574,27 +628,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil return nil
} }
// filterValidPostureChecks filters and returns the posture check IDs from the given list // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
// that are valid within the provided account. func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds))
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds { for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks { if _, exists := postureChecks[id]; exists {
if id == postureCheck.ID { validIDs = append(validIDs, id)
result = append(result, id)
continue
}
} }
} }
return result
return validIDs
} }
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. // getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func filterValidGroupIDs(account *Account, groupIDs []string) []string { func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs)) validIDs := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs { for _, id := range groupIDs {
if _, exists := account.Groups[groupID]; exists { if _, exists := groups[id]; exists {
result = append(result, groupID) validIDs = append(validIDs, id)
}
}
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
} }
} }
return result return result

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
@@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
}) })
var policyWithGroupRulesNoPeers *Policy
var policyWithDestinationPeersOnly *Policy
var policyWithSourceAndDestinationPeers *Policy
// Saving policy with rule groups with no peers should not update account's peers and not send peer update // Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) { t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-rule-groups-no-peers", go func() {
Enabled: true, peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupC"}, Destinations: []string{"groupC"},
@@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with source group containing peers, but destination group without peers should // Saving policy with source group containing peers, but destination group without peers should
// update account's peers and send peer update // update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-source-has-peers-destination-none", go func() {
Enabled: true, peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupB"}, Destinations: []string{"groupB"},
@@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination group containing peers, but source group without peers should // Saving policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update // update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-destination-has-peers-source-none", go func() {
Enabled: true, peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), Enabled: true,
Enabled: false,
Sources: []string{"groupC"}, Sources: []string{"groupC"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
Bidirectional: true, Bidirectional: true,
@@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination and source groups containing peers should update account's peers // Saving policy with destination and source groups containing peers should update account's peers
// and send peer update // and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-source-destination-peers", go func() {
Enabled: true, peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
@@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Disabling policy with destination and source groups containing peers should update account's peers // Disabling policy with destination and source groups containing peers should update account's peers
// and send peer update // and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) policyWithSourceAndDestinationPeers.Enabled = false
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Updating disabled policy with destination and source groups containing peers should not update account's peers // Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update // or send peer update
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Description: "updated description",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) policyWithSourceAndDestinationPeers.Description = "updated description"
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Enabling policy with destination and source groups containing peers should update account's peers // Enabling policy with destination and source groups containing peers should update account's peers
// and send peer update // and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) policyWithSourceAndDestinationPeers.Enabled = true
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy should trigger account peers update and send peer update // Deleting policy should trigger account peers update and send peer update
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
policyID := "policy-source-destination-peers"
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with destination group containing peers, but source group without peers should // Deleting policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update // update account's peers and send peer update
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
policyID := "policy-destination-has-peers-source-none"
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with no peers in groups should not update account's peers and not send peer update // Deleting policy with no peers in groups should not update account's peers and not send peer update
t.Run("deleting policy with no peers in groups", func(t *testing.T) { t.Run("deleting policy with no peers in groups", func(t *testing.T) {
policyID := "policy-rule-groups-no-peers"
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {

View File

@@ -7,8 +7,6 @@ import (
"regexp" "regexp"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
} }
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := Checks{ postureChecks := Checks{
ID: postureChecksID, ID: postureChecksID,
Name: name, Name: name,

View File

@@ -2,16 +2,15 @@ package server
import ( import (
"context" "context"
"fmt"
"slices" "slices"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) "github.com/rs/xid"
log "github.com/sirupsen/logrus"
const ( "golang.org/x/exp/maps"
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
) )
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -20,219 +19,279 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, err return nil, err
} }
if !user.HasAdminPower() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
} }
if !user.HasAdminPower() { if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewAdminPermissionError()
} }
if err := postureChecks.Validate(); err != nil { return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
return status.Errorf(status.InvalidArgument, err.Error()) //nolint }
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
} }
exists, uniqName := am.savePostureChecks(account, postureChecks) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
// we do not allow create new posture checks with non uniq name
if !exists && !uniqName {
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
} }
action := activity.PostureCheckCreated if !user.HasAdminPower() {
if exists { return nil, status.NewAdminPermissionError()
action = activity.PostureCheckUpdated
account.Network.IncSerial()
} }
if err = am.Store.SaveAccount(ctx, account); err != nil { var updateAccountPeers bool
return err var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
return err
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
action = activity.PostureCheckUpdated
}
postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
})
if err != nil {
return nil, err
} }
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return postureChecks, nil
} }
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
if err != nil { return status.NewUserNotPartOfAccountError()
return err
} }
if !user.HasAdminPower() { if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return status.NewAdminPermissionError()
} }
postureChecks, err := am.deletePostureChecks(account, postureChecksID) var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
})
if err != nil { if err != nil {
return err return err
} }
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
return nil return nil
} }
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
uniqName = true
for i, p := range account.PostureChecks {
if !exists && p.ID == postureChecks.ID {
account.PostureChecks[i] = postureChecks
exists = true
}
if p.Name == postureChecks.Name {
uniqName = false
}
}
if !exists {
account.PostureChecks = append(account.PostureChecks, postureChecks)
}
return
}
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
postureChecksIdx := -1
for i, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
return postureChecks, nil
}
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]posture.Checks) peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 { err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if len(postureChecks) == 0 {
return nil
}
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
for _, policy := range policies {
if !policy.Enabled {
continue
}
if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return maps.Values(peerPostureChecks), nil
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
}
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return err
}
return nil return nil
} }
for _, policy := range account.Policies { // For new posture checks, ensure no duplicates by name.
if !policy.Enabled { checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
continue if err != nil {
} return err
}
if isPeerInPolicySourceGroups(peer.ID, account, policy) { for _, check := range checks {
addPolicyPostureChecks(account, policy, peerPostureChecks) if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
} }
} }
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) postureChecks.ID = xid.New().String()
for _, check := range peerPostureChecks {
checkCopy := check return nil
postureChecksList = append(postureChecksList, &checkCopy) }
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
if err != nil {
return err
} }
return postureChecksList if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
if err != nil {
return err
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
} }
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if !rule.Enabled { if !rule.Enabled {
continue continue
} }
for _, sourceGroup := range rule.Sources { for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup] group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
if ok && slices.Contains(group.Peers, peerID) { if err != nil {
return true log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
}
if slices.Contains(group.Peers, peerID) {
return true, nil
} }
} }
} }
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil return false, nil
} }
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
if !exists { policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
return false if err != nil {
return err
} }
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) for _, policy := range policies {
if !isLinked { if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return false return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
}
} }
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
return nil
} }

View File

@@ -5,8 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/group"
@@ -16,7 +16,6 @@ import (
const ( const (
adminUserID = "adminUserID" adminUserID = "adminUserID"
regularUserID = "regularUserID" regularUserID = "regularUserID"
postureCheckID = "existing-id"
postureCheckName = "Existing check" postureCheckName = "Existing check"
) )
@@ -26,23 +25,22 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestPostureChecksAccount(am) accountID, err := initTestPostureChecksAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
t.Run("Generic posture check flow", func(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks // regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) _, err = am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{})
assert.Error(t, err) assert.Error(t, err)
// regular users cannot list check // regular users cannot list check
_, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) _, err = am.ListPostureChecks(context.Background(), accountID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// should be possible to create posture check with uniq name // should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ postureCheck, err := am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID,
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
@@ -53,13 +51,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// admin users can list check // admin users can list check
checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, checks, 1) assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name // should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ _, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: "new-id",
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{ GeoLocationCheck: &posture.GeoLocationCheck{
@@ -74,53 +71,53 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
// admins can update posture checks // admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ postureCheck.Checks = posture.ChecksDefinition{
ID: postureCheckID, NBVersionCheck: &posture.NBVersionCheck{
Name: postureCheckName, MinVersion: "0.27.0",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
}, },
}) }
_, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheck)
assert.NoError(t, err) assert.NoError(t, err)
// users should not be able to delete posture checks // users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// admin should be able to delete posture checks // admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, checks, 0) assert.Len(t, checks, 0)
}) })
} }
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) {
accountID := "testingAccount" accountID := "testingAccount"
domain := "example.com" domain := "example.com"
admin := &User{ err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
Id: adminUserID,
Role: UserRoleAdmin,
}
user := &User{
Id: regularUserID,
Role: UserRoleUser,
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
return am.Store.GetAccount(context.Background(), account.Id) err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
{
Id: adminUserID,
AccountID: accountID,
Role: UserRoleAdmin,
},
{
Id: regularUserID,
AccountID: accountID,
Role: UserRoleUser,
},
})
if err != nil {
return "", err
}
return accountID, nil
} }
func TestPostureCheckAccountPeersUpdate(t *testing.T) { func TestPostureCheckAccountPeersUpdate(t *testing.T) {
@@ -150,9 +147,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
}) })
postureCheck := posture.Checks{ postureCheckA := &posture.Checks{
ID: "postureCheck", Name: "postureCheckA",
Name: "postureCheck", AccountID: account.Id,
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
},
},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
require.NoError(t, err)
postureCheckB := &posture.Checks{
Name: "postureCheckB",
AccountID: account.Id, AccountID: account.Id,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
@@ -169,7 +179,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -187,12 +197,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -202,12 +212,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
}) })
policy := Policy{ policy := &Policy{
ID: "policyA",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} }
// Linking posture check to policy should trigger update account peers and send peer update // Linking posture check to policy should trigger update account peers and send peer update
@@ -226,7 +234,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture checks should update account peers and send peer update // Updating linked posture checks should update account peers and send peer update
t.Run("updating linked to posture check with peers", func(t *testing.T) { t.Run("updating linked to posture check with peers", func(t *testing.T) {
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
@@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -274,8 +282,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}() }()
policy.SourcePostureChecks = []string{} policy.SourcePostureChecks = []string{}
_, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -293,7 +300,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -303,17 +310,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
}) })
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
policy = Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policyB",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupC"}, Destinations: []string{"groupC"},
@@ -321,9 +326,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} })
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
@@ -332,12 +336,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -354,12 +358,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
}) })
policy = Policy{
ID: "policyB", _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -367,10 +370,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} })
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err) assert.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
@@ -379,12 +380,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -397,8 +398,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked client posture check to policy where source has peers but destination does not, // Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update // should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policyB",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -409,9 +409,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} })
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err) assert.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
@@ -420,7 +419,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{ ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{ Processes: []posture.Process{
{ {
@@ -429,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -440,80 +439,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}) })
} }
func TestArePostureCheckChangesAffectingPeers(t *testing.T) { func TestArePostureCheckChangesAffectPeers(t *testing.T) {
account := &Account{ manager, err := createManager(t)
Policies: []*Policy{ require.NoError(t, err, "failed to create account manager")
{
ID: "policyA", accountID, err := initTestPostureChecksAccount(manager)
Rules: []*PolicyRule{ require.NoError(t, err, "failed to init testing account")
{
Enabled: true, groupA := &group.Group{
Sources: []string{"groupA"}, ID: "groupA",
Destinations: []string{"groupA"}, AccountID: accountID,
}, Peers: []string{"peer1"},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
} }
groupB := &group.Group{
ID: "groupB",
AccountID: accountID,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
postureCheckA := &posture.Checks{
Name: "checkA",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
Name: "checkB",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckB, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB)
require.NoError(t, err, "failed to save postureCheckB")
policy := &Policy{
AccountID: accountID,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{postureCheckA.ID},
}
policy, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true) result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check does not exist", func(t *testing.T) { t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false) result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, "unknown")
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"} policy.Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"} policy.Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} groupA.Peers = []string{}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{} policy.Rules[0].Sources = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) policy.Rules[0].Destinations = []string{"nonExistentGroup"}
_, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
} }

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"unicode/utf8" "unicode/utf8"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -52,17 +53,46 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction Store, accountID string, checkRoute *route.Route, groupsMap map[string]*nbgroup.Group) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) prefix := checkRoute.Network
domains := checkRoute.Domains
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@@ -71,18 +101,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route // we skip route(s) with the same network ID as we want to allow updating of the existing route
// when creating a new route routeID is newly generated so nothing will be skipped // when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID { if checkRoute.ID == prefixRoute.ID {
continue continue
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, ok := peerGroupsMap[groupID]
if group == nil { if !ok || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@@ -95,12 +131,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
} }
if peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) _, err = transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -108,9 +145,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range checkRoute.PeerGroups {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
@@ -118,12 +154,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, group.Peers)
if err != nil {
return err
}
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, ok := peersMap[id]
if peer == nil { if !ok || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -146,104 +188,63 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Do not allow non-Linux peers if user.AccountID != accountID {
if peer := account.GetPeer(peerID); peer != nil { return nil, status.NewUserNotPartOfAccountError()
if peer.Meta.GoOS != "linux" { }
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
var newRoute *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
newRoute = &route.Route{
ID: route.ID(xid.New().String()),
AccountID: accountID,
Network: prefix,
Domains: domains,
KeepRoute: keepRoute,
NetID: netID,
Description: description,
Peer: peerID,
PeerGroups: peerGroupIDs,
NetworkType: networkType,
Masquerade: masquerade,
Metric: metric,
Enabled: enabled,
Groups: groups,
AccessControlGroups: accessControlGroupIDs,
} }
}
if len(domains) > 0 && prefix.IsValid() { if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return err
} }
if len(domains) == 0 && !prefix.IsValid() { updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
}
if len(domains) > 0 {
prefix = getPlaceholderIP()
}
if peerID != "" && len(peerGroupIDs) != 0 {
return nil, status.Errorf(
status.InvalidArgument,
"peer with ID %s and peers group %s should not be provided at the same time",
peerID, peerGroupIDs)
}
var newRoute route.Route
newRoute.ID = route.ID(xid.New().String())
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
if err != nil { if err != nil {
return nil, err return err
} }
}
if len(accessControlGroupIDs) > 0 { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
err = validateGroups(accessControlGroupIDs, account.Groups) return err
if err != nil {
return nil, err
} }
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) return transaction.SaveRoute(ctx, LockingStrengthUpdate, newRoute)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
@@ -251,10 +252,151 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
return err
}
oldRoute, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeToSave.ID))
if err != nil {
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
route, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID))
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
}
func validateRoute(ctx context.Context, transaction Store, accountID string, routeToSave *route.Route) error {
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
} }
if err := validateRouteProperties(routeToSave); err != nil {
return err
}
if routeToSave.Peer != "" {
peer, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer)
if err != nil {
return err
}
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
if err != nil {
return err
}
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
}
// Helper to validate route properties.
func validateRouteProperties(routeToSave *route.Route) error {
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric { if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
@@ -263,18 +405,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@@ -291,89 +421,34 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups)
if err != nil {
return err
}
}
if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
if err != nil {
return err
}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil {
return err
}
err = validateGroups(routeToSave.Groups, account.Groups)
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil return nil
} }
// DeleteRoute deletes route with routeID // validateRouteGroups validates the route groups and returns the validated groups map.
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { func validateRouteGroups(ctx context.Context, transaction Store, accountID string, routeToSave *route.Route) (map[string]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
defer unlock() groupsMap, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupsToValidate)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if len(routeToSave.PeerGroups) > 0 {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
return nil, err
}
} }
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) if len(routeToSave.AccessControlGroups) > 0 {
if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
return nil, err
}
}
if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
return nil, err
}
return groupsMap, nil
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
@@ -649,8 +724,21 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers.
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { func areRouteChangesAffectPeers(ctx context.Context, transaction Store, route *route.Route) (bool, error) {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, route.AccountID, route.PeerGroups)
} }

View File

@@ -5,9 +5,11 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -427,22 +429,22 @@ func TestCreateRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Errorf("failed to init testing account: %s", err) t.Errorf("failed to init testing account: %s", err)
} }
if testCase.createInitRoute { if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll() groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, errInit) require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit) require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) _, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit) require.NoError(t, errInit)
} }
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err) testCase.errFunc(t, err)
if !testCase.shouldCreate { if !testCase.shouldCreate {
@@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) {
// assign generated ID // assign generated ID
testCase.expectedRoute.ID = outRoute.ID testCase.expectedRoute.ID = outRoute.ID
testCase.expectedRoute.AccountID = accountID
if !testCase.expectedRoute.IsEqual(outRoute) { if !testCase.expectedRoute.IsEqual(outRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute) t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute)
@@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
if testCase.createInitRoute { if testCase.createInitRoute {
account.Routes["initRoute"] = &route.Route{ initRoute := &route.Route{
ID: "initRoute", ID: "initRoute",
AccountID: accountID,
Network: existingNetwork, Network: existingNetwork,
NetID: existingRouteID, NetID: existingRouteID,
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
@@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) {
Enabled: true, Enabled: true,
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
} }
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute)
require.NoError(t, err, "failed to save init route")
} }
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute testCase.existingRoute.AccountID = accountID
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save existing route")
if err != nil {
t.Error("account should be saved")
}
var routeToSave *route.Route var routeToSave *route.Route
@@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) {
} }
} }
err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) err = am.SaveRoute(context.Background(), accountID, userID, routeToSave)
testCase.errFunc(t, err) testCase.errFunc(t, err)
@@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) {
return return
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID)
if err != nil { require.NoError(t, err, "failed to get saved route")
t.Fatal(err)
}
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved)
testCase.expectedRoute.AccountID = accountID
if !testCase.expectedRoute.IsEqual(savedRoute) { if !testCase.expectedRoute.IsEqual(savedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
} }
@@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) {
} }
func TestDeleteRoute(t *testing.T) { func TestDeleteRoute(t *testing.T) {
testingRoute := &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
Domains: domain.List{"domain1", "domain2"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1Key,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
}
am, err := createRouterManager(t) am, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.Routes[testingRoute.ID] = testingRoute err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "GroupA",
AccountID: accountID,
Name: "GroupA",
})
require.NoError(t, err, "failed to save group")
err = am.Store.SaveAccount(context.Background(), account) testingRoute := &route.Route{
if err != nil { Network: netip.MustParsePrefix("192.168.0.0/16"),
t.Error("failed to save account") NetID: route.NetID("12345678901234567890qw"),
Groups: []string{"GroupA"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
} }
createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute)
require.NoError(t, err, "failed to create route")
err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID)
if err != nil { if err != nil {
t.Error("deleting route failed with error: ", err) t.Error("deleting route failed with error: ", err)
} }
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) _, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID)
if err != nil { require.NotNil(t, err)
t.Error("failed to retrieve saved account with error: ", err) sErr, ok := status.FromError(err)
} require.True(t, ok)
require.Equal(t, status.NotFound, sErr.Type())
_, found := savedAccount.Routes[testingRoute.ID]
if found {
t.Error("route shouldn't be found after delete")
}
} }
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
@@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { require.NoError(t, err, "failed to init testing account")
t.Error("failed to init testing account")
}
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true) require.Equal(t, newRoute.Enabled, true)
@@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {
@@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
} }
} }
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID)
require.NoError(t, err) require.NoError(t, err)
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err) require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err) require.NoError(t, err)
peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1158,7 +1153,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@@ -1167,7 +1162,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err) require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1181,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
expectedRoute := enabledRoute.Copy() expectedRoute := enabledRoute.Copy()
expectedRoute.Peer = peer1Key expectedRoute.Peer = peer1Key
err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute)
require.NoError(t, err) require.NoError(t, err)
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1193,7 +1188,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group")
err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID)
require.NoError(t, err) require.NoError(t, err)
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
@@ -1206,23 +1201,22 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
Name: "peer1 group", Name: "peer1 group",
Peers: []string{peer1ID}, Peers: []string{peer1ID},
} }
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) err = am.SaveGroup(context.Background(), accountID, userID, newGroup)
require.NoError(t, err) require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") rules, err := am.ListPolicies(context.Background(), accountID, "testingUser")
require.NoError(t, err) require.NoError(t, err)
defaultRule := rules[0] defaultRule := rules[0]
newPolicy := defaultRule.Copy() newPolicy := defaultRule.Copy()
newPolicy.ID = xid.New().String()
newPolicy.Name = "peer1 only" newPolicy.Name = "peer1 only"
newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) _, err = am.SavePolicy(context.Background(), accountID, userID, newPolicy)
require.NoError(t, err) require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1233,7 +1227,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2")
err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1267,179 +1261,104 @@ func createRouterStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain) err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
ips := account.getTakenIPs() createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) {
peer1IP, err := AllocatePeerIP(account.Network.Net, ips) ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerIP, err := AllocatePeerIP(network.Net, ips)
if err != nil {
return nil, err
}
peer := &nbpeer.Peer{
IP: peerIP,
AccountID: accountID,
ID: peerID,
Key: peerKey,
Name: peerName,
DNSLabel: dnsLabel,
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: peerName,
GoOS: strings.ToLower(kernel),
Kernel: kernel,
Core: core,
Platform: platform,
OS: os,
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil {
return nil, err
}
return peer, nil
}
// Create peers
peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil { if err != nil {
return nil, err return "", err
} }
peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu")
peer1 := &nbpeer.Peer{
IP: peer1IP,
ID: peer1ID,
Key: peer1Key,
Name: "test-host1@netbird.io",
DNSLabel: "test-host1",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer1.ID] = peer1
ips = account.getTakenIPs()
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin")
peer2 := &nbpeer.Peer{
IP: peer2IP,
ID: peer2ID,
Key: peer2Key,
Name: "test-host2@netbird.io",
DNSLabel: "test-host2",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer2.ID] = peer2
ips = account.getTakenIPs()
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu")
peer3 := &nbpeer.Peer{
IP: peer3IP,
ID: peer3ID,
Key: peer3Key,
Name: "test-host3@netbird.io",
DNSLabel: "test-host3",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
Kernel: "Darwin",
Core: "13.4.1",
Platform: "arm64",
OS: "darwin",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer3.ID] = peer3
ips = account.getTakenIPs()
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu")
peer4 := &nbpeer.Peer{
IP: peer4IP,
ID: peer4ID,
Key: peer4Key,
Name: "test-host4@netbird.io",
DNSLabel: "test-host4",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer4.ID] = peer4
ips = account.getTakenIPs()
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer5 := &nbpeer.Peer{ groupAll, err := am.GetGroupByName(context.Background(), "All", accountID)
IP: peer5IP, if err != nil {
ID: peer5ID, return "", err
Key: peer5Key,
Name: "test-host5@netbird.io",
DNSLabel: "test-host5",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host5@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer5.ID] = peer5
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
groupAll, err := account.GetGroupAll()
if err != nil {
return nil, err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
newGroup := []*nbgroup.Group{ newGroups := []*nbgroup.Group{
{ {
ID: routeGroup1, ID: routeGroup1,
Name: routeGroup1, Name: routeGroup1,
@@ -1471,15 +1390,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
Peers: []string{peer1.ID, peer4.ID}, Peers: []string{peer1.ID, peer4.ID},
}, },
} }
err = am.SaveGroups(context.Background(), accountID, userID, newGroups)
for _, group := range newGroup { if err != nil {
err = am.SaveGroup(context.Background(), accountID, userID, group) return "", err
if err != nil {
return nil, err
}
} }
return am.Store.GetAccount(context.Background(), account.Id) return accountID, nil
} }
func TestAccount_getPeersRoutesFirewall(t *testing.T) { func TestAccount_getPeersRoutesFirewall(t *testing.T) {
@@ -1783,10 +1699,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t) manager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager") require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager) accountID, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account") require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{ {
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
@@ -1832,7 +1748,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.Groups, []string{}, true, userID, route.KeepRoute,
) )
@@ -1868,7 +1784,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.Groups, []string{}, true, userID, route.KeepRoute,
) )
@@ -1904,7 +1820,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
newRoute, err := manager.CreateRoute( newRoute, err := manager.CreateRoute(
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
) )
@@ -1928,7 +1844,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute)
require.NoError(t, err) require.NoError(t, err)
select { select {
@@ -1946,7 +1862,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
select { select {
@@ -1970,7 +1886,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
} }
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
) )
@@ -1982,7 +1898,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupB", ID: "groupB",
Name: "GroupB", Name: "GroupB",
Peers: []string{peer1ID}, Peers: []string{peer1ID},
@@ -2010,7 +1926,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{"groupC"}, Groups: []string{"groupC"},
} }
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
) )
@@ -2022,7 +1938,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupC", ID: "groupC",
Name: "GroupC", Name: "GroupC",
Peers: []string{peer1ID}, Peers: []string{peer1ID},

View File

@@ -449,14 +449,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil return nil
} }
for _, g := range removedGroups { for _, g := range removedGroups {
group, ok := groups[g] group, ok := groups[g]
if !ok { if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
continue continue
} }
@@ -469,7 +469,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
for _, g := range addedGroups { for _, g := range addedGroups {
group, ok := groups[g] group, ok := groups[g]
if !ok { if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
continue continue
} }

View File

@@ -25,12 +25,12 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{ {
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
@@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
expiresIn := time.Hour expiresIn := time.Hour
keyName := "my-test-key" keyName := "my-test-key"
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{},
SetupKeyUnlimitedUsage, userID, false) SetupKeyUnlimitedUsage, userID, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups := []string{"group_1", "group_2"} autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key" newKeyName := "my-new-test-key"
revoked := true revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName, Name: newKeyName,
Revoked: revoked, Revoked: revoked,
@@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
key.Id, time.Now().UTC(), autoGroups, true) key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked)
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, newKeyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID) assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll() groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
assert.NoError(t, err) assert.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error // saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID) autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ _, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName, Name: newKeyName,
Revoked: revoked, Revoked: revoked,
@@ -103,12 +103,12 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@@ -117,7 +117,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},
@@ -126,7 +126,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
groupAll, err := account.GetGroupAll() groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
assert.NoError(t, err) assert.NoError(t, err)
type testCase struct { type testCase struct {
@@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
for _, tCase := range []testCase{testCase1, testCase2, testCase3} { for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
if tCase.expectedFailure { if tCase.expectedFailure {
@@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
tCase.expectedUpdatedAt, tCase.expectedGroups, false) tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) ev := getEvent(t, accountID, manager, activity.SetupKeyCreated)
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"]) assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"])
assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"])) assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
@@ -208,12 +208,12 @@ func TestGetSetupKeys(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@@ -222,7 +222,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},
@@ -390,8 +390,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
policy := Policy{ policy := &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -403,7 +402,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,6 @@ package status
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
) )
const ( const (
@@ -87,9 +86,14 @@ func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey) return Errorf(NotFound, "account not found: %s", accountKey)
} }
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
func NewPeerNotPartOfAccountError() error {
return Errorf(PermissionDenied, "peer is not part of this account")
}
// NewUserNotFoundError creates a new Error with NotFound type for a missing user // NewUserNotFoundError creates a new Error with NotFound type for a missing user
func NewUserNotFoundError(userKey string) error { func NewUserNotFoundError(userKey string) error {
return Errorf(NotFound, "user not found: %s", userKey) return Errorf(NotFound, "user: %s not found", userKey)
} }
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer // NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
@@ -126,11 +130,6 @@ func NewAdminPermissionError() error {
return Errorf(PermissionDenied, "admin role required to perform this action") return Errorf(PermissionDenied, "admin role required to perform this action")
} }
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
}
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
func NewInvalidKeyIDError() error { func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID") return Errorf(InvalidArgument, "invalid key ID")
@@ -140,3 +139,42 @@ func NewInvalidKeyIDError() error {
func NewGetAccountError(err error) error { func NewGetAccountError(err error) error {
return Errorf(Internal, "error getting account: %s", err) return Errorf(Internal, "error getting account: %s", err)
} }
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
}
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
func NewPostureChecksNotFoundError(postureChecksID string) error {
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
}
// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
func NewPolicyNotFoundError(policyID string) error {
return Errorf(NotFound, "policy: %s not found", policyID)
}
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
func NewNameServerGroupNotFoundError(nsGroupID string) error {
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
}
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
func NewServiceUserRoleInvalidError() error {
return Errorf(InvalidArgument, "can't create a service user with owner role")
}
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
// to delete a user with the owner role.
func NewOwnerDeletePermissionError() error {
return Errorf(PermissionDenied, "can't delete a user with the owner role")
}
func NewPATNotFoundError(patID string) error {
return Errorf(NotFound, "PAT: %s not found", patID)
}
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}

View File

@@ -48,51 +48,78 @@ type Store interface {
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
GetTotalAccounts(ctx context.Context) (int64, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
SavePeerLocation(accountID string, peer *nbpeer.Peer) error GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
@@ -102,13 +129,17 @@ type Store interface {
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string

View File

@@ -34,4 +34,8 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@@ -25,7 +25,7 @@ CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');

View File

@@ -1,6 +1,6 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
@@ -27,9 +27,13 @@ CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@@ -34,6 +34,6 @@ INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038'
INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',1,'""','','',0);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

File diff suppressed because it is too large Load Diff

View File

@@ -43,37 +43,34 @@ const (
func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
} }
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
if err != nil { if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
assert.Equal(t, pat.CreatedBy, mockUserID) assert.Equal(t, newPAT.CreatedBy, mockUserID)
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
if err != nil { if err != nil {
t.Fatalf("Error when getting token ID by hashed token: %s", err) t.Fatalf("Error when getting token ID by hashed token: %s", err)
} }
if tokenID == "" { if pat.ID == "" {
t.Fatal("GetTokenIDByHashedToken failed after adding PAT") t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
} }
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, newPAT.ID, pat.ID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }
@@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId, Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: false, IsServiceUser: false,
} })
err := store.SaveAccount(context.Background(), account) assert.NoError(t, err, "failed to create user")
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId, Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: true, IsServiceUser: true,
} })
err := store.SaveAccount(context.Background(), account) assert.NoError(t, err, "failed to create user")
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
func TestUser_CreatePAT_WithEmptyName(t *testing.T) { func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
} }
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
assert.Errorf(t, err, "Wrong expiration should thorw error") assert.Errorf(t, err, "Wrong expiration should throw error")
} }
func TestUser_DeletePAT(t *testing.T) { func TestUser_DeletePAT(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
PATs: map[string]*PersonalAccessToken{
mockTokenID1: { err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID1, ID: mockTokenID1,
HashedToken: mockToken1, UserID: mockUserID,
}, HashedToken: mockToken1,
}, })
} assert.NoError(t, err, "failed to create PAT")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
if err != nil { if err != nil {
t.Fatalf("Error when getting account: %s", err) t.Fatalf("Error when getting account: %s", err)
} }
@@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) {
func TestUser_GetPAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
mockTokenID1: { ID: mockTokenID1,
ID: mockTokenID1, UserID: mockUserID,
HashedToken: mockToken1, HashedToken: mockToken1,
}, })
}, assert.NoError(t, err, "failed to create PAT")
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) {
func TestUser_GetAllPATs(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
mockTokenID1: { ID: mockTokenID1,
ID: mockTokenID1, UserID: mockUserID,
HashedToken: mockToken1, HashedToken: mockToken1,
}, })
mockTokenID2: { assert.NoError(t, err, "failed to create PAT")
ID: mockTokenID2,
HashedToken: mockToken2, err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
}, ID: mockTokenID2,
}, UserID: mockUserID,
} HashedToken: mockToken2,
err := store.SaveAccount(context.Background(), account) })
if err != nil { assert.NoError(t, err, "failed to create PAT")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) {
func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when creating service user: %s", err) t.Fatalf("Error when creating service user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(account.Users)) assert.Equal(t, 2, len(account.Users))
@@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) {
func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when creating user: %s", err) t.Fatalf("Error when creating user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, user.IsServiceUser) assert.True(t, user.IsServiceUser)
@@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
func TestUser_InviteNewUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -549,13 +519,13 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = tt.serviceUser
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} tt.serviceUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser)
assert.NoError(t, err, "failed to create service user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -582,12 +552,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -603,39 +570,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2" err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[targetId] = &User{ assert.NoError(t, err, "failed to create account")
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5" err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
account.Users[targetId] = &User{ {
Id: targetId, Id: "user2",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: true,
Role: UserRoleOwner, ServiceUserName: "user2username",
} },
{
err := store.SaveAccount(context.Background(), account) Id: "user3",
if err != nil { AccountID: mockAccountID,
t.Fatalf("Error when saving account: %s", err) IsServiceUser: false,
} Issued: UserIssuedAPI,
},
{
Id: "user4",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedIntegration,
},
{
Id: "user5",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
},
})
assert.NoError(t, err, "failed to save users")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -685,61 +651,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
func TestUser_DeleteUser_RegularUsers(t *testing.T) { func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2" err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[targetId] = &User{ assert.NoError(t, err, "failed to create account")
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5" err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
account.Users[targetId] = &User{ {
Id: targetId, Id: "user2",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: true,
Role: UserRoleOwner, ServiceUserName: "user2username",
} },
account.Users["user6"] = &User{ {
Id: "user6", Id: "user3",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: false,
} Issued: UserIssuedAPI,
account.Users["user7"] = &User{ },
Id: "user7", {
IsServiceUser: false, Id: "user4",
Issued: UserIssuedAPI, AccountID: mockAccountID,
} IsServiceUser: false,
account.Users["user8"] = &User{ Issued: UserIssuedIntegration,
Id: "user8", },
IsServiceUser: false, {
Issued: UserIssuedAPI, Id: "user5",
Role: UserRoleAdmin, AccountID: mockAccountID,
} IsServiceUser: false,
account.Users["user9"] = &User{ Issued: UserIssuedAPI,
Id: "user9", Role: UserRoleOwner,
IsServiceUser: false, },
Issued: UserIssuedAPI, {
Role: UserRoleAdmin, Id: "user6",
} AccountID: mockAccountID,
IsServiceUser: false,
err := store.SaveAccount(context.Background(), account) Issued: UserIssuedAPI,
if err != nil { },
t.Fatalf("Error when saving account: %s", err) {
} Id: "user7",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user8",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
{
Id: "user9",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
})
assert.NoError(t, err)
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -786,7 +755,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
{ {
name: "Delete non-existent user", name: "Delete non-existent user",
userIDs: []string{"non-existent-user"}, userIDs: []string{"non-existent-user"},
expectedReasons: []string{"target user: non-existent-user not found"}, expectedReasons: []string{"user: non-existent-user not found"},
expectedNotDeleted: []string{}, expectedNotDeleted: []string{},
}, },
{ {
@@ -816,7 +785,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
acc, err := am.Store.GetAccount(context.Background(), account.Id) acc, err := am.Store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {
@@ -836,12 +805,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -865,14 +831,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} newUser := NewRegularUser("normal_user1")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
newUser = NewRegularUser("normal_user2")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -946,15 +917,25 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID)
assert.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings)
assert.NoError(t, err, "failed to save account settings")
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID)
assert.NoError(t, err, "failed to delete user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -968,7 +949,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
assert.Equal(t, 1, len(users)) assert.Equal(t, 1, len(users))
userInfo, _ := users[0].ToUserInfo(nil, account.Settings) userInfo, _ := users[0].ToUserInfo(nil, settings)
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
}) })
} }
@@ -978,22 +959,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
externalUser := &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: "externalUser", assert.NoError(t, err, "failed to create account")
Role: UserRoleUser,
Issued: UserIssuedIntegration, err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: "externalUser",
AccountID: mockAccountID,
Role: UserRoleUser,
Issued: UserIssuedIntegration,
IntegrationReference: integration_reference.IntegrationReference{ IntegrationReference: integration_reference.IntegrationReference{
ID: 1, ID: 1,
IntegrationType: "external", IntegrationType: "external",
}, },
} })
account.Users[externalUser.Id] = externalUser assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -1013,6 +993,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager() cacheManager := am.GetExternalCacheManager()
externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser")
assert.NoError(t, err, "failed to get user")
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"})
assert.NoError(t, err) assert.NoError(t, err)
@@ -1042,17 +1026,17 @@ func TestUser_IsAdmin(t *testing.T) {
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID, Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
} })
assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -1071,17 +1055,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{ assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID, Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
} })
assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -1240,21 +1223,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// create an account and an admin user // create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// create other users // create other users
account.Users[regularUserID] = NewRegularUser(regularUserID) regularUser := NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID) regularUser.AccountID = accountID
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(context.Background(), account) adminUser := NewAdminUser(adminUserID)
if err != nil { adminUser.AccountID = accountID
t.Fatal(err)
serviceUser := &User{
Id: serviceUserID,
AccountID: accountID,
IsServiceUser: true,
Role: UserRoleAdmin,
ServiceUserName: "service",
} }
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser})
assert.NoError(t, err, "failed to save users")
updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update)
if tc.expectedErr { if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error") require.Errorf(t, err, "expecting SaveUser to throw an error")
} else { } else {
@@ -1279,8 +1271,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
policy := Policy{ policy := &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1292,7 +1283,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -88,18 +88,18 @@ type Route struct {
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// Network and Domains are mutually exclusive // Network and Domains are mutually exclusive
Network netip.Prefix `gorm:"serializer:json"` Network netip.Prefix `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"` Domains domain.List `gorm:"serializer:json"`
KeepRoute bool KeepRoute bool
NetID NetID NetID NetID
Description string Description string
Peer string Peer string
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
Metric int Metric int
Enabled bool Enabled bool
Groups []string `gorm:"serializer:json"` Groups []string `gorm:"serializer:json"`
AccessControlGroups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"`
} }
@@ -111,19 +111,20 @@ func (r *Route) EventMeta() map[string]any {
// Copy copies a route object // Copy copies a route object
func (r *Route) Copy() *Route { func (r *Route) Copy() *Route {
route := &Route{ route := &Route{
ID: r.ID, ID: r.ID,
Description: r.Description, AccountID: r.AccountID,
NetID: r.NetID, Description: r.Description,
Network: r.Network, NetID: r.NetID,
Domains: slices.Clone(r.Domains), Network: r.Network,
KeepRoute: r.KeepRoute, Domains: slices.Clone(r.Domains),
NetworkType: r.NetworkType, KeepRoute: r.KeepRoute,
Peer: r.Peer, NetworkType: r.NetworkType,
PeerGroups: slices.Clone(r.PeerGroups), Peer: r.Peer,
Metric: r.Metric, PeerGroups: slices.Clone(r.PeerGroups),
Masquerade: r.Masquerade, Metric: r.Metric,
Enabled: r.Enabled, Masquerade: r.Masquerade,
Groups: slices.Clone(r.Groups), Enabled: r.Enabled,
Groups: slices.Clone(r.Groups),
AccessControlGroups: slices.Clone(r.AccessControlGroups), AccessControlGroups: slices.Clone(r.AccessControlGroups),
} }
return route return route
@@ -149,7 +150,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.Masquerade == r.Masquerade && other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled && other.Enabled == r.Enabled &&
slices.Equal(r.Groups, other.Groups) && slices.Equal(r.Groups, other.Groups) &&
slices.Equal(r.PeerGroups, other.PeerGroups)&& slices.Equal(r.PeerGroups, other.PeerGroups) &&
slices.Equal(r.AccessControlGroups, other.AccessControlGroups) slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
} }