Compare commits

...

78 Commits

Author SHA1 Message Date
bcmmbaga
a23a09bba3 Fix failed to create policy and delete user PAT on postgres
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-07 18:34:07 +03:00
bcmmbaga
2f7027194b Remove code duplicate on peer
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-07 12:56:54 +03:00
bcmmbaga
197d844a16 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-07 11:39:24 +03:00
bcmmbaga
df6c9a528a Refactor UpdatePeer method to defer event logging and scheduling until after peer save
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-06 16:01:07 +03:00
bcmmbaga
9cb7336ef5 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-06 15:59:12 +03:00
bcmmbaga
e513e51e9f Handle new account creation directly within the store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-04 14:43:14 +03:00
bcmmbaga
4ad00e784c Remove redundant accounts All group check on startup
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-03 18:49:16 +03:00
bcmmbaga
bfeb7f0875 Refactor users updating
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-03 01:14:26 +03:00
bcmmbaga
dde01b8e02 Refactor user and peers delete
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-02 16:05:12 +03:00
bcmmbaga
74246d18ba Merge branch 'main' into refactor/get-account-usage
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-02 06:41:34 +03:00
bcmmbaga
fa5db7d7ee Refactor service user handling, user cache lookup, and cache loading
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-01 20:27:52 +03:00
bcmmbaga
fed48de83f Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-01 14:02:09 +03:00
bcmmbaga
e73b5da42b Refactor update account peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-31 22:30:13 +03:00
bcmmbaga
8cacdae70c Merge branch 'main' into refactor/get-account-usage
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-31 21:59:09 +03:00
bcmmbaga
6b94f6e4e7 Refactor ephemeral peers and mark PAT as used
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-31 21:50:05 +03:00
bcmmbaga
b7525d9fe8 Merge branch 'main' into refactor/get-account-usage 2024-10-30 22:36:47 +03:00
bcmmbaga
901d283114 Merge branch 'main' into refactor-get-account-by-token
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-30 22:34:59 +03:00
bcmmbaga
7278a21b0d refactor get account in peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-29 13:50:44 +03:00
bcmmbaga
9bf0bf4843 wip: refactor get account in peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-28 17:47:54 +03:00
bcmmbaga
313e158e20 Refactor route
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-25 13:18:24 +03:00
bcmmbaga
0bdcb41e20 Refactor peer expiry, inactivity, location and status update to remove get account
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-23 19:03:48 +03:00
bcmmbaga
97dbdd7940 fix group tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-18 10:48:28 +03:00
bcmmbaga
a82b5ce80e Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/account.go
2024-10-17 22:01:26 +03:00
bcmmbaga
83be99c849 refactor get peers posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 21:58:34 +03:00
bcmmbaga
ee96a81b83 fix handler tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 16:34:44 +03:00
bcmmbaga
b0edc5f1f7 Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/sql_store.go
2024-10-17 16:10:16 +03:00
bcmmbaga
408d0cd504 Refactor policy save and delete
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 14:11:22 +03:00
bcmmbaga
b66f331711 get the first element when get record by ID
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 14:10:01 +03:00
bcmmbaga
d7a6996bed check user accounts for setup keys
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 11:59:46 +03:00
bcmmbaga
d7c63d5c04 Remove get account from groups ops
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-16 16:04:34 +03:00
bcmmbaga
1123729c1c fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-15 18:17:47 +03:00
bcmmbaga
a8c8b77df8 Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/account.go
#	management/server/file_store.go
#	management/server/peer.go
#	management/server/policy.go
#	management/server/route.go
#	management/server/sql_store.go
#	management/server/store.go
#	management/server/user.go
2024-10-14 14:31:55 +03:00
bcmmbaga
0297b5f142 wip: refactoring
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-02 11:56:47 +03:00
bcmmbaga
78e238646c refactor groups methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 16:32:31 +03:00
bcmmbaga
f9ed25f8b1 wip refactor peer methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 01:07:48 +03:00
bcmmbaga
f43a006c34 Fix posture check name uniqueness per account
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 01:06:52 +03:00
bcmmbaga
1a37b12d1b refactor user PAT
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 00:55:32 +03:00
bcmmbaga
d36d30dec4 refactor name server groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 00:54:53 +03:00
bcmmbaga
43eb7261e3 refactor account and dns settings
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 00:54:28 +03:00
bcmmbaga
9e47c94a7f refactor setup keys
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-30 14:02:55 +03:00
bcmmbaga
edf67672ad fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-27 22:31:26 +03:00
bcmmbaga
bc520412ba Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/file_store.go
#	management/server/http/posture_checks_handler.go
#	management/server/mock_server/account_mock.go
#	management/server/policy.go
#	management/server/sql_store.go
#	management/server/store.go
2024-09-27 20:27:05 +03:00
bcmmbaga
d87fe0257b Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage 2024-09-26 19:48:17 +03:00
bcmmbaga
b1b2b0adf0 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 19:47:43 +03:00
bcmmbaga
96f18c2c8c fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 19:46:37 +03:00
bcmmbaga
73be8c8a32 fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 19:00:59 +03:00
bcmmbaga
f61c914fd7 Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage
# Conflicts:
#	management/server/file_store.go
2024-09-26 18:51:47 +03:00
bcmmbaga
4575ae2841 add store lock
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 18:46:23 +03:00
bcmmbaga
ca6a9fd602 Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage 2024-09-26 16:39:52 +03:00
bcmmbaga
871595d15f Merge branch 'main' into refactor-get-account-by-token
# Conflicts:
#	management/server/sql_store.go
2024-09-26 16:39:17 +03:00
bcmmbaga
30253b0565 Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage 2024-09-26 16:34:36 +03:00
bcmmbaga
dc82c2d1ce fix add missing policy source posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 16:34:19 +03:00
bcmmbaga
3b4bcdf5a4 refactor posture checks save and deletion
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 16:28:49 +03:00
bcmmbaga
87c8430e99 add store policy save and method
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 22:47:54 +03:00
bcmmbaga
c384874d7d fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 15:04:57 +03:00
bcmmbaga
b815393180 fix lint
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 13:02:08 +03:00
bcmmbaga
41b212f610 Refactor store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 12:53:20 +03:00
bcmmbaga
16174f0478 Refactor route, setupkey, nameserver and dns to get record(s) from store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 12:52:42 +03:00
bcmmbaga
d14b855670 Refactor user permissions and retrieves PAT
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 22:57:04 +03:00
bcmmbaga
eab85644cd Refactor retrieval of policy and posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 21:57:33 +03:00
bcmmbaga
7561706627 add GetGroupByID from store and refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 19:55:33 +03:00
bcmmbaga
1ffe89d20d add GetGroupByName from store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 16:36:57 +03:00
bcmmbaga
28840383e1 refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 13:30:13 +03:00
bcmmbaga
d9f612d623 remove locks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-23 20:12:57 +03:00
bcmmbaga
7601a17150 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-22 23:44:10 +03:00
bcmmbaga
8f98adddf6 refactor handlers to use GetAccountIDFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-22 15:14:31 +03:00
bcmmbaga
26dd045da5 Merge branch 'main' into refactor-get-account-by-token 2024-09-20 14:08:09 +03:00
bcmmbaga
4d9bb7ea35 refactor getAccountWithAuthorizationClaims to return account id
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-20 14:07:44 +03:00
bcmmbaga
9631cb4fb3 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 19:05:21 +03:00
bcmmbaga
8f9c54f6c2 remove GetUserByID from account manager
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 17:03:04 +03:00
bcmmbaga
f60a4234b1 revert handles change
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 16:40:47 +03:00
bcmmbaga
021fc8f33e fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 16:11:20 +03:00
bcmmbaga
a4c4158bcf Merge branch 'main' into refactor-get-account-by-token 2024-09-18 16:03:55 +03:00
bcmmbaga
720d36a290 refactor getAccountWithAuthorizationClaims
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 15:55:52 +03:00
bcmmbaga
ccab3b427f refactor getAccountFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 14:24:39 +03:00
bcmmbaga
e5d55d3c10 refactor handlers to get account when necessary
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-17 23:15:54 +03:00
bcmmbaga
3cf1b02f31 refactor jwt groups extractor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-17 18:18:59 +03:00
bcmmbaga
258b30cf48 refactor access control middleware and user access by JWT groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-16 13:33:36 +03:00
49 changed files with 4701 additions and 3535 deletions

View File

@@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -171,7 +170,7 @@ type Engine struct {
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer

View File

@@ -28,9 +28,12 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');
COMMIT; COMMIT;

View File

@@ -136,6 +136,7 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
func (g *NameServerGroup) Copy() *NameServerGroup { func (g *NameServerGroup) Copy() *NameServerGroup {
nsGroup := &NameServerGroup{ nsGroup := &NameServerGroup{
ID: g.ID, ID: g.ID,
AccountID: g.AccountID,
Name: g.Name, Name: g.Name,
Description: g.Description, Description: g.Description,
NameServers: make([]NameServer, len(g.NameServers)), NameServers: make([]NameServer, len(g.NameServers)),
@@ -156,6 +157,7 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
// IsEqual compares one nameserver group with the other // IsEqual compares one nameserver group with the other
func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool { func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID && return other.ID == g.ID &&
other.AccountID == g.AccountID &&
other.Name == g.Name && other.Name == g.Name &&
other.Description == g.Description && other.Description == g.Description &&
other.Primary == g.Primary && other.Primary == g.Primary &&

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -397,7 +398,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
for _, testCase := range tt { for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") store := newStore(t)
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), "account-1")
require.NoError(t, err, "failed to get account")
account.UpdateSettings(&testCase.accountSettings) account.UpdateSettings(&testCase.accountSettings)
account.Network = network account.Network = network
account.Peers = testCase.peers account.Peers = testCase.peers
@@ -415,6 +423,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
store.Close(context.Background())
} }
} }
@@ -422,7 +432,15 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io" domain := "netbird.io"
userId := "account_creator" userId := "account_creator"
accountID := "account_id" accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
store := newStore(t)
defer store.Close(context.Background())
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
} }
@@ -433,16 +451,16 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userID) t.Fatalf("expected to create an account for a user %s", userID)
return return
} }
account, err = manager.Store.GetAccountByUser(context.Background(), userID) account, err := manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return return
@@ -665,15 +683,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
@@ -689,44 +704,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, accountGroups, 1, "only ALL group should exists")
}) })
t.Run("JWT groups enabled without claim name", func(t *testing.T) { t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount) _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "save account failed") require.NoError(t, err, "failed to update account settings")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get account ids")
require.Len(t, accountIDs, 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
}) })
t.Run("JWT groups enabled", func(t *testing.T) { t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups" initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount) _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "save account failed") require.NoError(t, err, "failed to update account settings")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get account ids")
require.Len(t, accountIDs, 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to check account existence")
require.True(t, exists, "account should exist")
require.Len(t, account.Groups, 3, "groups should be added to the account") accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
require.NoError(t, err, "failed to get account groups")
require.Len(t, accountGroups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{} groupsByNames := map[string]*group.Group{}
for _, g := range account.Groups { for _, g := range accountGroups {
groupsByNames[g.Name] = g groupsByNames[g.Name] = g
} }
@@ -744,60 +768,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser", userPAT := &PersonalAccessToken{
PATs: map[string]*PersonalAccessToken{ ID: "tokenId",
"tokenId": { UserID: "testuser",
ID: "tokenId", HashedToken: encodedHashedToken,
HashedToken: encodedHashedToken, CreatedAt: time.Now().UTC(),
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
} }
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
if err != nil { if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err) t.Fatalf("Error when getting Account from PAT: %s", err)
} }
assert.Equal(t, "account_id", account.Id) assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id) assert.Equal(t, "testuser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) assert.Equal(t, userPAT, pat)
} }
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser", userPAT := &PersonalAccessToken{
PATs: map[string]*PersonalAccessToken{ ID: "tokenId",
"tokenId": { UserID: "someUser",
ID: "tokenId", HashedToken: encodedHashedToken,
HashedToken: encodedHashedToken, LastUsed: time.Time{},
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -808,11 +825,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when marking PAT used: %s", err) t.Fatalf("Error when marking PAT used: %s", err)
} }
account, err = am.Store.GetAccount(context.Background(), "account_id") userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
if err != nil { require.NoError(t, err, "failed to get PAT")
t.Fatalf("Error when getting account: %s", err)
} assert.True(t, !userPAT.LastUsed.IsZero())
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
} }
func TestAccountManager_PrivateAccount(t *testing.T) { func TestAccountManager_PrivateAccount(t *testing.T) {
@@ -823,15 +839,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
} }
userId := "test_user" userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.Store.GetAccountByUser(context.Background(), userId) account, err := manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -850,32 +866,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user" userId := "test_user"
domain := "hotmail.com" domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatal(err) require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
}
if account == nil {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account != nil && account.Domain != domain { accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) require.NoError(t, err, "failed to get account domain and category")
} require.Equal(t, domain, accDomain, "expected account domain to match")
domain = "gmail.com" domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
if account == nil { accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
t.Fatalf("expected to get an account for a user %s", userId) require.NoError(t, err, "failed to get account domain and category")
} require.Equal(t, domain, accDomain, "expected account domain to match")
if account != nil && account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
} }
func TestAccountManager_GetAccountByUserID(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
@@ -907,12 +913,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
} }
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain) err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return am.Store.GetAccount(context.Background(), accountID)
} }
func TestAccountManager_GetAccount(t *testing.T) { func TestAccountManager_GetAccount(t *testing.T) {
@@ -1055,23 +1060,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatal(err)
}
serial := account.Network.CurrentSerial() // should be 0 network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account network")
if account.Network.Serial != 0 { serial := network.CurrentSerial() // should be 0
t.Errorf("expecting account network to have an initial Serial=0") require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
return
}
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { require.NoError(t, err, "failed to generate private key")
t.Fatal(err)
return
}
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedUserID := userID expectedUserID := userID
@@ -1079,16 +1079,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
if err != nil { require.NoError(t, err, "failed to add peer")
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
return
}
account, err = manager.Store.GetAccount(context.Background(), account.Id) account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil { require.NoError(t, err, "failed to get account")
t.Fatal(err)
return
}
if peer.Key != expectedPeerKey { if peer.Key != expectedPeerKey {
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
@@ -1131,10 +1125,12 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
} }
policy := Policy{ policy := Policy{
ID: "policy", ID: xid.New().String(),
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -1212,10 +1208,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policyID := xid.New().String()
policy := Policy{ policy := Policy{
Enabled: true, ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: policyID,
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -1249,19 +1250,25 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t) manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := group.Group{ group := group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
} }
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err) t.Errorf("save group: %v", err)
return return
} }
policyID := xid.New().String()
policy := Policy{ policy := Policy{
Enabled: true, ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: policyID,
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -1302,19 +1309,24 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ group := group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
} }
err := manager.SaveGroup(context.Background(), account.Id, userID, &group)
require.NoError(t, err, "failed to save group")
policyID := xid.New().String()
policy := Policy{ policy := Policy{
Enabled: true, ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: policyID,
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -1324,6 +1336,9 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
}, },
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
@@ -1352,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return return
} }
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, userID, group.ID); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -1748,18 +1763,9 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
@@ -1774,11 +1780,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first // disable expiration first
update := peer.Copy() update := peer.Copy()
update.LoginExpirationEnabled = false update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine // enabling expiration should trigger the routine
update.LoginExpirationEnabled = true update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1802,10 +1808,13 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpirationEnabled: true, require.NoError(t, err, "failed to get account settings")
})
settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@@ -1822,11 +1831,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger // when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1854,10 +1860,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@@ -1871,10 +1874,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}, },
} }
// enabling PeerLoginExpirationEnabled should trigger the expiration job // enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpiration: time.Hour, require.NoError(t, err, "failed to get account settings")
PeerLoginExpirationEnabled: true,
}) settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
settings, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1884,10 +1889,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1) wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel // disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Hour, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second) failed = waitTimeout(wg, time.Second)
if failed { if failed {
@@ -1902,30 +1905,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
settings.PeerLoginExpirationEnabled = false
settings.PeerLoginExpiration = time.Hour
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour) assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Second, settings.PeerLoginExpiration = time.Second
PeerLoginExpirationEnabled: false, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Hour * 24 * 181, settings.PeerLoginExpiration = time.Hour * 24 * 181
PeerLoginExpirationEnabled: false, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }
@@ -2606,7 +2608,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2626,7 +2628,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2665,7 +2667,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims) err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")

View File

@@ -6,6 +6,7 @@ import (
"strconv" "strconv"
"sync" "sync"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,105 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil { if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.NewAdminPermissionError()
}
oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, groups)
if err != nil { if err != nil {
return err return err
} }
} }
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial() updateAccountPeers, err := am.areDNSSettingChangesAffectPeers(ctx, accountID, addedGroups, removedGroups)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return fmt.Errorf("failed to check if dns settings changes affect peers: %w", err)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, id := range addedGroups { for _, id := range addedGroups {
group := account.GetGroup(id) group, ok := groupMap[id]
meta := map[string]any{"group": group.Name, "group_id": group.ID} if ok {
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
} }
for _, id := range removedGroups { for _, id := range removedGroups {
group := account.GetGroup(id) group, ok := groupMap[id]
meta := map[string]any{"group": group.Name, "group_id": group.ID} if ok {
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func (am *DefaultAccountManager) areDNSSettingChangesAffectPeers(ctx context.Context, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := am.anyGroupHasPeers(ctx, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, accountID, removedGroups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ protoUpdate := &proto.DNSConfig{

View File

@@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Fatal("failed to init testing account") t.Fatal("failed to init testing account")
} }
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil { if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("DNS settings for new accounts shouldn't return nil") t.Fatal("DNS settings for new accounts shouldn't return nil")
} }
account.DNSSettings = DNSSettings{ err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{
DisabledManagementGroups: []string{group1ID}, DisabledManagementGroups: []string{group1ID},
} })
require.NoError(t, err, "failed to update DNS settings")
err = am.Store.SaveAccount(context.Background(), account) dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil { if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
} }
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) _, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID)
if err == nil { if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user") t.Errorf("An error should be returned when getting the DNS settings with a regular user")
} }
@@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings)
if err != nil { if err != nil {
if testCase.shouldFail { if testCase.shouldFail {
return return
@@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err) t.Error(err)
} }
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) updatedAccount, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil { if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err) t.Errorf("should be able to retrieve updated account, got err: %s", err)
} }
@@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
peer1, err := account.FindPeerByPubKey(dnsPeer1Key) peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
peer2, err := account.FindPeerByPubKey(dnsPeer2Key) peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy() accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account DNS settings")
dnsSettings := accountDNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to update DNS settings")
require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err) require.NoError(t, err)
@@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
Key: dnsPeer1Key, Key: dnsPeer1Key,
@@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com" domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain)
if err != nil {
account.Users[dnsRegularUserID] = &User{ return "", err
Id: dnsRegularUserID,
Role: UserRoleUser,
} }
err := am.Store.SaveAccount(context.Background(), account) err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: dnsRegularUserID,
AccountID: dnsAccountID,
Role: UserRoleUser,
})
if err != nil { if err != nil {
return nil, err return "", err
} }
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil { if err != nil {
return nil, err return "", err
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer1, err = account.FindPeerByPubKey(peer1.Key) _, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, err = account.FindPeerByPubKey(peer2.Key) err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{
{
ID: dnsGroup1ID,
AccountID: dnsAccountID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
},
{
ID: dnsGroup2ID,
AccountID: dnsAccountID,
Name: dnsGroup2ID,
},
})
if err != nil { if err != nil {
return nil, err return "", err
} }
newGroup1 := &group.Group{ allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All")
ID: dnsGroup1ID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
}
newGroup2 := &group.Group{
ID: dnsGroup2ID,
Name: dnsGroup2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
return nil, err return "", err
} }
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{
ID: dnsNSGroup1, ID: dnsNSGroup1,
Name: "ns-group-1", AccountID: dnsAccountID,
Name: "ns-group-1",
NameServers: []dns.NameServer{{ NameServers: []dns.NameServer{{
IP: netip.MustParseAddr(savedPeer1.IP.String()), IP: netip.MustParseAddr(savedPeer1.IP.String()),
NSType: dns.UDPNameServerType, NSType: dns.UDPNameServerType,
@@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Primary: true, Primary: true,
Enabled: true, Enabled: true,
Groups: []string{allGroup.ID}, Groups: []string{allGroup.ID},
} })
err = am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
return am.Store.GetAccount(context.Background(), account.Id) return dnsAccountID, nil
} }
func generateTestData(size int) nbdns.Config { func generateTestData(size int) nbdns.Config {

View File

@@ -20,10 +20,10 @@ var (
) )
type ephemeralPeer struct { type ephemeralPeer struct {
id string id string
account *Account accountID string
deadline time.Time deadline time.Time
next *ephemeralPeer next *ephemeralPeer
} }
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
@@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
e.peersLock.Lock() e.peersLock.Lock()
defer e.peersLock.Unlock() defer e.peersLock.Unlock()
@@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
return return
} }
e.addPeer(peer.ID, a, newDeadLine()) e.addPeer(peer.AccountID, peer.ID, newDeadLine())
if e.timer == nil { if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx) e.cleanup(ctx)
@@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background()) peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := newDeadLine() t := newDeadLine()
count := 0 count := 0
for _, a := range accounts { for _, p := range peers {
for id, p := range a.Peers { if p.Ephemeral {
if p.Ephemeral { count++
count++ e.addPeer(p.AccountID, p.ID, t)
e.addPeer(id, a, t)
}
} }
} }
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
} }
@@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
for id, p := range deletePeers { for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} }
} }
} }
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{ ep := &ephemeralPeer{
id: id, id: peerID,
account: account, accountID: accountID,
deadline: deadline, deadline: deadline,
} }
if e.headPeer == nil { if e.headPeer == nil {

View File

@@ -7,25 +7,12 @@ import (
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/require"
) )
type MockStore struct { type MockStore struct {
Store Store
account *Account accountID string
}
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil
}
return nil, status.NewPeerNotFoundError(peerId)
} }
type MocAccountManager struct { type MocAccountManager struct {
@@ -33,9 +20,8 @@ type MocAccountManager struct {
store *MockStore store *MockStore
} }
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error {
delete(a.store.account.Peers, peerID) return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
return nil //nolint:nil
} }
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {
@@ -44,23 +30,26 @@ func TestNewManager(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers { peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) require.NoError(t, err, "failed to get account peers")
} require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers")
} }
func TestNewManagerPeerConnected(t *testing.T) { func TestNewManagerPeerConnected(t *testing.T) {
@@ -69,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerConnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
expected := numberOfPeers + 1 peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
if len(store.account.Peers) != expected { require.NoError(t, err, "failed to get account peers")
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers")
}
} }
func TestNewManagerPeerDisconnected(t *testing.T) { func TestNewManagerPeerDisconnected(t *testing.T) {
@@ -97,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
for _, v := range peers {
mgr.OnPeerConnected(context.Background(), v)
} }
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerDisconnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
expected := numberOfPeers + numberOfEphemeralPeers - 1 expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected { require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers")
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
} }
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error {
store.account = newAccountWithId(context.Background(), "my account", "", "") accountID := "my account"
err := newAccountWithId(context.Background(), store, accountID, "", "")
if err != nil {
return err
}
store.accountID = accountID
for i := 0; i < numberOfPeers; i++ { for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i) peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
AccountID: accountID,
Ephemeral: false, Ephemeral: false,
} }
store.account.Peers[p.ID] = p err = store.AddPeerToAccount(context.Background(), p)
if err != nil {
return err
}
} }
for i := 0; i < numberOfEphemeralPeers; i++ { for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i) peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
AccountID: accountID,
Ephemeral: true, Ephemeral: true,
} }
store.account.Peers[p.ID] = p err = store.AddPeerToAccount(context.Background(), p)
if err != nil {
return err
}
} }
return nil
} }

View File

@@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return status.NewAdminPermissionError()
} }
return nil return nil
@@ -50,7 +54,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -59,31 +63,34 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, accountID) return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
} }
// SaveGroups adds new groups to the account. // SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
var eventsToStore []func() if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var (
eventsToStore []func()
groupsToSave []*nbgroup.Group
)
for _, newGroup := range newGroups { for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
@@ -91,7 +98,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name) existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil { if err != nil {
s, ok := status.FromError(err) s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound { if !ok || s.ErrorType != status.NotFound {
@@ -109,15 +116,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil { if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
} }
oldGroup := account.Groups[newGroup.ID] newGroup.AccountID = accountID
account.Groups[newGroup.ID] = newGroup groupsToSave = append(groupsToSave, newGroup)
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) events := am.prepareGroupEvents(ctx, userID, accountID, newGroup)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
} }
@@ -126,30 +133,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
newGroupIDs = append(newGroupIDs, newGroup.ID) newGroupIDs = append(newGroupIDs, newGroup.ID)
} }
account.Network.IncSerial() updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
return fmt.Errorf("failed to save groups: %w", err)
}
return nil
})
if err != nil {
return err
} }
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@@ -159,12 +181,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { for _, peerID := range addedPeers {
peer := account.Peers[p] peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
@@ -175,12 +198,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
@@ -210,119 +234,108 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if user.AccountID != accountID {
if !ok { return status.NewUserNotPartOfAccountError()
return nil
} }
allGroup, err := account.GetGroupAll() group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
if allGroup.ID == groupID { if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
} }
if err = validateDeleteGroup(account, group, userId); err != nil { if err = am.validateDeleteGroup(ctx, group, userID); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
return nil return nil
} }
// DeleteGroups deletes groups from an account. // DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
var allErrors error if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var (
allErrors error
groupIDsToDelete []string
deletedGroups []*nbgroup.Group
)
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, ok := account.Groups[groupID] group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if !ok { if err != nil {
continue continue
} }
if err := validateDeleteGroup(account, group, userId); err != nil { if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue continue
} }
delete(account.Groups, groupID) groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
account.Network.IncSerial() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
for _, g := range deletedGroups { for _, group := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true add := true
for _, itemID := range group.Peers { for _, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
@@ -334,13 +347,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
} }
account.Network.IncSerial() updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -348,41 +375,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] updated := false
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers { for i, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil { updated = true
return err break
}
} }
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { if !updated {
am.updateAccountPeers(ctx, account) return nil
}
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
@@ -390,32 +431,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return err
}
if settings.Extra != nil {
if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name} return &GroupLinkError{"integrated validator", group.Name}
} }
} }
@@ -424,17 +475,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) {
routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r return true, r
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies { for _, policy := range policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +510,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
} }
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups { for _, dns := range nameServerGroups {
for _, g := range dns.Groups { for _, g := range dns.Groups {
if g == groupID { if g == groupID {
@@ -454,11 +524,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
} }
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys { for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) { if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey return true, setupKey
@@ -468,7 +545,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
} }
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) {
users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return true, user return true, user
@@ -478,30 +561,46 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers. // anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool { func (am *DefaultAccountManager) anyGroupHasPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return true if err != nil {
return false, err
} }
}
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { if group.HasPeers() {
for _, groupID := range groupIDs { return true, nil
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
} }
} }
return false return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked {
return true, nil
}
}
return false, nil
} }

View File

@@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
} }
routeResource := &route.Route{ routeResource := &route.Route{
ID: "example route", ID: "example route",
Groups: []string{groupForRoute.ID}, AccountID: accountID,
Groups: []string{groupForRoute.ID},
} }
routePeerGroupResource := &route.Route{ routePeerGroupResource := &route.Route{
ID: "example route with peer groups", ID: "example route with peer groups",
AccountID: accountID,
PeerGroups: []string{groupForRoute2.ID}, PeerGroups: []string{groupForRoute2.ID},
} }
nameServerGroup := &nbdns.NameServerGroup{ nameServerGroup := &nbdns.NameServerGroup{
ID: "example name server group", ID: "example name server group",
Groups: []string{groupForNameServerGroups.ID}, AccountID: accountID,
Groups: []string{groupForNameServerGroups.ID},
} }
policy := &Policy{ policy := &Policy{
ID: "example policy", ID: "example policy",
AccountID: accountID,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "example policy rule", ID: "example policy rule",
PolicyID: "example policy",
Destinations: []string{groupForPolicies.ID}, Destinations: []string{groupForPolicies.ID},
}, },
}, },
@@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
setupKey := &SetupKey{ setupKey := &SetupKey{
Id: "example setup key", Id: "example setup key",
AccountID: accountID,
AutoGroups: []string{groupForSetupKeys.ID}, AutoGroups: []string{groupForSetupKeys.ID},
} }
user := &User{ user := &User{
Id: "example user", Id: "example user",
AccountID: accountID,
AutoGroups: []string{groupForUsers.ID}, AutoGroups: []string{groupForUsers.ID},
} }
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) if err != nil {
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) return nil, nil, err
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
acc, err := am.Store.GetAccount(context.Background(), account.Id) err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup)
if err != nil {
return nil, nil, err
}
err = am.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user)
if err != nil {
return nil, nil, err
}
err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{
groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies,
groupForSetupKeys, groupForUsers, groupForUsers, groupForIntegration,
})
if err != nil {
return nil, nil, err
}
acc, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -394,24 +424,28 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{ {
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}, },
{ {
ID: "groupB", ID: "groupB",
Name: "GroupB", AccountID: account.Id,
Peers: []string{}, Name: "GroupB",
Peers: []string{},
}, },
{ {
ID: "groupC", ID: "groupC",
Name: "GroupC", AccountID: account.Id,
Peers: []string{peer1.ID, peer3.ID}, Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
}, },
{ {
ID: "groupD", ID: "groupD",
Name: "GroupD", AccountID: account.Id,
Peers: []string{}, Name: "GroupD",
Peers: []string{},
}, },
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -430,9 +464,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}() }()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupB", ID: "groupB",
Name: "GroupB", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID}, Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -501,10 +536,13 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// adding a group to policy // adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy", ID: "policy",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: "policy",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -524,9 +562,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}() }()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -593,9 +632,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}() }()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupC", ID: "groupC",
Name: "GroupC", AccountID: account.Id,
Peers: []string{peer1.ID, peer3.ID}, Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -610,6 +650,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to route", func(t *testing.T) { t.Run("saving group linked to route", func(t *testing.T) {
newRoute := route.Route{ newRoute := route.Route{
ID: "route", ID: "route",
AccountID: account.Id,
Network: netip.MustParsePrefix("192.168.0.0/16"), Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superNet", NetID: "superNet",
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
@@ -634,9 +675,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -661,9 +703,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupD", ID: "groupD",
Name: "GroupD", AccountID: account.Id,
Peers: []string{peer1.ID}, Name: "GroupD",
Peers: []string{peer1.ID},
}) })
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -100,13 +100,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedAccountSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) resp := toAccountResponse(accountID, updatedAccountSettings)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }

View File

@@ -29,7 +29,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil return account.Settings, nil
}, },
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -39,9 +39,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
accCopy := account.Copy() return newSettings.Copy(), nil
accCopy.UpdateSettings(newSettings)
return accCopy, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@@ -68,7 +68,7 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return nil, fmt.Errorf("unknown group name") return nil, fmt.Errorf("unknown group name")
}, },
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { GetUserPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil return maps.Values(TestPeers), nil
}, },
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {

View File

@@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
) )
authMiddleware := middleware.NewAuthMiddleware( authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT, accountManager.GetAccountInfoFromPAT,
jwtValidator.ValidateAndParse, jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed, accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups, accountManager.CheckUserAccessByJWTGroups,

View File

@@ -9,9 +9,9 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
@@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// GetAccountFromPATFunc function // GetAccountInfoFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function // ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@@ -47,7 +47,7 @@ const (
) )
// NewAuthMiddleware instance constructor // NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware { audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" { if userIdClaim == "" {
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
} }
return &AuthMiddleware{ return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT, getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken, validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed, markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// If an error occurs, call the error handler and return an error // If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
validatedToken, err := m.validateAndParseToken(r.Context(), token) validatedToken, err := m.validateAndParseToken(r.Context(), token)
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth) token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
account, user, pat, err := m.getAccountFromPAT(r.Context(), token) user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil { if err != nil {
return fmt.Errorf("invalid Token: %w", err) return fmt.Errorf("invalid Token: %w", err)
} }
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{} claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@@ -33,7 +33,8 @@ var testAccount = &server.Account{
Domain: domain, Domain: domain,
Users: map[string]*server.User{ Users: map[string]*server.User{
userID: { userID: {
Id: userID, Id: userID,
AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{ PATs: map[string]*server.PersonalAccessToken{
tokenID: { tokenID: {
ID: tokenID, ID: tokenID,
@@ -49,11 +50,11 @@ var testAccount = &server.Account{
}, },
} }
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
if token == PAT { if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
} }
return nil, nil, nil, fmt.Errorf("PAT invalid") return nil, nil, "", "", fmt.Errorf("PAT invalid")
} }
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
@@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
) )
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT, mockGetAccountInfoFromPAT,
mockValidateAndParseToken, mockValidateAndParseToken,
mockMarkPATUsed, mockMarkPATUsed,
mockCheckUserAccessByJWTGroups, mockCheckUserAccessByJWTGroups,

View File

@@ -120,6 +120,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
updatedNSGroup := &nbdns.NameServerGroup{ updatedNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID, ID: nsGroupID,
AccountID: accountID,
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Primary: req.Primary, Primary: req.Primary,

View File

@@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil return peerToReturn, nil
} }
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
@@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
groupsInfo := toGroupsInfo(account.Groups, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, err, w)
return
}
groupsInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
@@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
} }
} }
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(account) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w) h.deletePeer(r.Context(), accountID, userID, peerID, w)
return return
case http.MethodGet, http.MethodPut: case http.MethodGet:
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) h.getPeer(r.Context(), accountID, peerID, userID, w)
if err != nil { return
util.WriteError(r.Context(), err, w) case http.MethodPut:
return h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
return return
default: default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
return return
} }
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers)) respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range account.Peers { for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(account) validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
} }
} }
dnsDomain := h.accountManager.GetDNSDomain() validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
@@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
} }
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum groupsInfo := make([]api.GroupMinimum, 0, len(groups))
groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
_, ok := groupsChecked[group.ID] groupsInfo = append(groupsInfo, api.GroupMinimum{
if ok { Id: group.ID,
continue Name: group.Name,
} PeersCount: len(group.Peers),
groupsChecked[group.ID] = struct{}{} })
for _, pk := range group.Peers {
if pk == peerID {
info := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
groupsInfo = append(groupsInfo, info)
break
}
}
} }
return groupsInfo return groupsInfo
} }

View File

@@ -39,6 +39,68 @@ const (
) )
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: "test_id",
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: "test_id",
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: "test_id",
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return &PeersHandler{ return &PeersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -64,77 +126,37 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
} }
return p, nil return p, nil
}, },
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { GetUserPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil return peers, nil
}, },
ListPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
peersID := make([]string, len(peers))
for _, peer := range peers {
peersID = append(peersID, peer.ID)
}
return []*nbgroup.Group{
{
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: peersID,
},
}, nil
},
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func() string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
}, },
GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) {
return account, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return account, nil return account, nil
}, },
HasConnectedChannelFunc: func(peerID string) bool { HasConnectedChannelFunc: func(peerID string) bool {

View File

@@ -130,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy := server.Policy{ policy := server.Policy{
ID: policyID, ID: policyID,
AccountID: accountID,
Name: req.Name, Name: req.Name,
Enabled: req.Enabled, Enabled: req.Enabled,
Description: req.Description, Description: req.Description,

View File

@@ -163,13 +163,16 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
} }
} }
isUpdate := postureChecksID != ""
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID) postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
postureChecks.AccountID = accountID
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, isUpdate); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -40,7 +40,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
} }
return p, nil return p, nil
}, },
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks, _ bool) error {
postureChecks.ID = "postureCheck" postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks testPostureChecks[postureChecks.ID] = postureChecks

View File

@@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -56,13 +58,15 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
if len(groups) == 0 { if len(groups) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(ctx, accountId)
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountId)
if err != nil { if err != nil {
return false, err return false, err
} }
for _, group := range groups { for _, group := range groups {
var found bool var found bool
for _, accountGroup := range accountsGroups { for _, accountGroup := range accountGroups {
if accountGroup.ID == group { if accountGroup.ID == group {
found = true found = true
break break
@@ -76,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
return true, nil return true, nil
} }
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
} }

View File

@@ -461,7 +461,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 10 * time.Second, Time: 10 * time.Second,
Timeout: 2 * time.Second, Timeout: 200 * time.Second,
})) }))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -22,16 +22,17 @@ import (
) )
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@@ -45,16 +46,16 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -89,15 +90,15 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error) GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager GetIdpManagerFunc func() idp.Manager
@@ -123,7 +124,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
if am.SyncAndMarkPeerFunc != nil { if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncAndMarkPeer is not implemented")
} }
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
@@ -131,7 +132,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me") panic("implement me")
} }
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeers := make(map[string]struct{}) approvedPeers := make(map[string]struct{})
for id := range account.Peers { for id := range account.Peers {
approvedPeers[id] = struct{}{} approvedPeers[id] = struct{}{}
@@ -171,16 +177,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
} }
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountIDByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser( func (am *MockAccountManager) GetOrCreateAccountIDByUser(
ctx context.Context, userId, domain string, ctx context.Context, userId, domain string,
) (*server.Account, error) { ) (string, error) {
if am.GetOrCreateAccountByUserFunc != nil { if am.GetOrCreateAccountIDByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain)
} }
return nil, status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetOrCreateAccountByUser is not implemented", "method GetOrCreateAccountIDByUser is not implemented",
) )
} }
@@ -222,19 +228,19 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
} }
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
if am.MarkPeerConnectedFunc != nil { if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
} }
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface // GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
if am.GetAccountFromPATFunc != nil { if am.GetAccountInfoFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat) return am.GetAccountInfoFromPATFunc(ctx, token)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
} }
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
@@ -354,14 +360,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {
@@ -626,12 +624,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
} }
// GetPeers mocks GetPeers of the AccountManager interface // GetUserPeers mocks GetUserPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetPeersFunc != nil { if am.GetUserPeersFunc != nil {
return am.GetPeersFunc(ctx, accountID, userID) return am.GetUserPeersFunc(ctx, accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented")
} }
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface // GetDNSDomain mocks GetDNSDomain of the AccountManager interface
@@ -675,7 +673,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
} }
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
if am.UpdateAccountSettingsFunc != nil { if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
} }
@@ -691,9 +689,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
} }
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, account) return am.SyncPeerFunc(ctx, sync, accountID)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
} }
@@ -739,9 +737,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
} }
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface // SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
if am.SavePostureChecksFunc != nil { if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, isUpdate)
} }
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
} }
@@ -840,3 +838,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
} }
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
if am.GetPeerGroupsFunc != nil {
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}
// ListPeers mocks ListPeers of the AccountManager interface
func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.ListPeersFunc != nil {
return am.ListPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented")
}

View File

@@ -3,7 +3,9 @@ package server
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"regexp" "regexp"
"slices"
"unicode/utf8" "unicode/utf8"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -24,26 +26,31 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) if user.IsRegularUser() {
return nil, status.NewUnauthorizedToViewNSGroupsError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
newNSGroup := &nbdns.NameServerGroup{ newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID,
Name: name, Name: name,
Description: description, Description: description,
NameServers: nameServerList, NameServers: nameServerList,
@@ -54,92 +61,136 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled, SearchDomainsEnabled: searchDomainEnabled,
} }
err = validateNameServerGroup(false, newNSGroup, account) err = am.validateNameServerGroup(ctx, accountID, newNSGroup)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if account.NameServerGroups == nil { updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, newNSGroup.Groups)
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) if err != nil {
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err return nil, err
} }
if anyGroupHasPeers(account, newNSGroup.Groups) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup); err != nil {
return fmt.Errorf("failed to create nameserver group: %w", err)
}
return nil
})
if err != nil {
return nil, err
} }
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
// SaveNameServerGroup saves nameserver group // SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
} }
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
err = validateNameServerGroup(true, nsGroupToSave, account) if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
oldNSGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID
if err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave); err != nil {
return err
}
updateAccountPeers, err := am.areNameServerGroupChangesAffectPeers(ctx, nsGroupToSave, oldNSGroup)
if err != nil { if err != nil {
return err return err
} }
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
account.Network.IncSerial() if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil { return fmt.Errorf("failed to update nameserver group: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
nsGroup := account.NameServerGroups[nsGroupID] if user.AccountID != accountID {
if nsGroup == nil { return status.NewUserNotPartOfAccountError()
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
} }
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial() nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
if anyGroupHasPeers(account, nsGroup.Groups) { updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, nsGroup.Groups)
am.updateAccountPeers(ctx, account) if err != nil {
return err
} }
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
return fmt.Errorf("failed to delete nameserver group: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
@@ -150,39 +201,44 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewUnauthorizedToViewNSGroupsError()
} }
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
} }
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil { if err != nil {
return err return err
} }
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers) err = validateNSList(nameserverGroup.NameServers)
if err != nil { if err != nil {
return err return err
} }
err = validateGroups(nameserverGroup.Groups, account.Groups) nsServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, groups)
if err != nil { if err != nil {
return err return err
} }
@@ -190,6 +246,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
return nil return nil
} }
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(ctx context.Context, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := am.anyGroupHasPeers(ctx, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 { if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
@@ -213,14 +287,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil return nil
} }
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
} }
for _, nsGroup := range nsGroupMap { for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID { if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
} }
} }
@@ -228,14 +302,14 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
} }
func validateNSList(list []nbdns.NameServer) error { func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list) nsListLength := len(list)
if nsListLenght == 0 || nsListLenght > 3 { if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
} }
return nil return nil
} }
func validateGroups(list []string, groups map[string]*nbgroup.Group) error { func validateGroups(list []string, groups []*nbgroup.Group) error {
if len(list) == 0 { if len(list) == 0 {
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
} }
@@ -244,13 +318,8 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" { if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string") return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
} }
found := false
for groupID := range groups { found := slices.ContainsFunc(groups, func(group *nbgroup.Group) bool { return group.ID == id })
if id == groupID {
found = true
break
}
}
if !found { if !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id) return status.Errorf(status.InvalidArgument, "group id %s not found", id)
} }
@@ -277,11 +346,3 @@ func validateDomain(domain string) error {
return nil return nil
} }
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@@ -6,6 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
outNSGroup, err := am.CreateNameServerGroup( outNSGroup, err := am.CreateNameServerGroup(
context.Background(), context.Background(),
account.Id, accountID,
testCase.inputArgs.name, testCase.inputArgs.name,
testCase.inputArgs.description, testCase.inputArgs.description,
testCase.inputArgs.nameServers, testCase.inputArgs.nameServers,
@@ -408,7 +409,7 @@ func TestCreateNameServerGroup(t *testing.T) {
// assign generated ID // assign generated ID
testCase.expectedNSGroup.ID = outNSGroup.ID testCase.expectedNSGroup.ID = outNSGroup.ID
testCase.expectedNSGroup.AccountID = accountID
if !testCase.expectedNSGroup.IsEqual(outNSGroup) { if !testCase.expectedNSGroup.IsEqual(outNSGroup) {
t.Errorf("new nameserver group didn't match expected ns group:\nGot %#v\nExpected:%#v\n", outNSGroup, testCase.expectedNSGroup) t.Errorf("new nameserver group didn't match expected ns group:\nGot %#v\nExpected:%#v\n", outNSGroup, testCase.expectedNSGroup)
} }
@@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup testCase.existingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save existing nameserver group")
if err != nil {
t.Error("account should be saved")
}
var nsGroupToSave *nbdns.NameServerGroup var nsGroupToSave *nbdns.NameServerGroup
if !testCase.skipCopying { if !testCase.skipCopying {
nsGroupToSave = testCase.existingNSGroup.Copy() nsGroupToSave = testCase.existingNSGroup.Copy()
@@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) {
} }
} }
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave)
testCase.errFunc(t, err) testCase.errFunc(t, err)
if !testCase.shouldCreate { if !testCase.shouldCreate {
return return
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID)
if err != nil { require.NoError(t, err, "failed to get saved nameserver group")
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved)
testCase.expectedNSGroup.AccountID = accountID
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) { if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup) t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
} }
@@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup testingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup)
require.NoError(t, err, "failed to save nameserver group")
err = am.Store.SaveAccount(context.Background(), account) err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
if err != nil { if err != nil {
t.Error("deleting nameserver group failed with error: ", err) t.Error("deleting nameserver group failed with error: ", err)
} }
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) _, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID)
if err != nil { require.NotNil(t, err)
t.Error("failed to retrieve saved account with error: ", err) sErr, ok := status.FromError(err)
} require.True(t, ok, "error should be a status error")
assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete")
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
if found {
t.Error("nameserver group shouldn't be found after delete")
}
} }
func TestGetNameServerGroup(t *testing.T) { func TestGetNameServerGroup(t *testing.T) {
@@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID)
if err != nil { if err != nil {
t.Error("getting existing nameserver group failed with error: ", err) t.Error("getting existing nameserver group failed with error: ", err)
} }
@@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("got a nil group while getting nameserver group with ID") t.Error("got a nil group while getting nameserver group with ID")
} }
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") _, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing")
if err == nil { if err == nil {
t.Error("getting not existing nameserver group should return error, got nil") t.Error("getting not existing nameserver group should return error, got nil")
} }
@@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
accountID := "testingAcc"
userID := testUserID
domain := "example.com"
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
Key: nsGroupPeer1Key, Key: nsGroupPeer1Key,
Name: "test-host1@netbird.io", Name: "test-host1@netbird.io",
@@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
} }
existingNSGroup := nbdns.NameServerGroup{ existingNSGroup := nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
AccountID: accountID,
Name: existingNSGroupName, Name: existingNSGroupName,
Description: "", Description: "",
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
@@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
Enabled: true, Enabled: true,
} }
accountID := "testingAcc" err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
userID := testUserID
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
newGroup1 := &nbgroup.Group{
ID: group1ID,
Name: group1ID,
}
newGroup2 := &nbgroup.Group{
ID: group2ID,
Name: group2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup)
if err != nil {
return "", err
}
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{
{
ID: group1ID,
AccountID: accountID,
Name: group1ID,
},
{
ID: group2ID,
AccountID: accountID,
Name: group2ID,
},
})
if err != nil {
return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
if err != nil { if err != nil {
return nil, err return "", err
} }
return account, nil return accountID, nil
} }
func TestValidateDomain(t *testing.T) { func TestValidateDomain(t *testing.T) {

View File

@@ -11,6 +11,7 @@ import (
"sync" "sync"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -50,34 +51,54 @@ type PeerLogin struct {
ConnectionIP net.IP ConnectionIP net.IP
} }
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // ListPeers returns a list of peers under the given account.
func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
}
// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peers := make([]*nbpeer.Peer, 0) peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
regularUser := !user.HasAdminPower() && !user.IsServiceUser if user.IsRegularUser() && settings.RegularUsersViewBlocked {
if regularUser && account.Settings.RegularUsersViewBlocked {
return peers, nil return peers, nil
} }
for _, peer := range account.Peers { accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if regularUser && user.Id != peer.UserID { if err != nil {
return nil, err
}
for _, peer := range accountPeers {
if user.IsRegularUser() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin // only display peers that belong to the current user if the current user is not an admin
continue continue
} }
@@ -86,10 +107,15 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peersMap[peer.ID] = p peersMap[peer.ID] = p
} }
if !regularUser { if user.IsAdminOrServiceUser() {
return peers, nil return peers, nil
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
@@ -107,37 +133,42 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
} }
// MarkPeerConnected marks peer as connected (true) or disconnected (false) // MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
if err != nil { if err != nil {
return err return err
} }
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
if err != nil { if err != nil {
return err return err
} }
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
} }
if expired { if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from // we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again. // the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy() oldStatus := peer.Status.Copy()
newStatus := oldStatus newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC() newStatus.LastSeen = time.Now().UTC()
@@ -157,16 +188,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
peer.Location.CountryCode = location.Country.ISOCode peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID peer.Location.GeoNameID = location.City.GeonameID
err = am.Store.SavePeerLocation(account.Id, peer) err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
} }
} }
} }
account.UpdatePeer(peer) err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -176,39 +205,50 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peer := account.GetPeer(update.ID) if user.AccountID != accountID {
if peer == nil { return nil, status.NewUserNotPartOfAccountError()
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
} }
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
if err != nil {
return nil, err
}
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
if err != nil {
return nil, err
}
var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool
if peer.SSHEnabled != update.SSHEnabled { if peer.SSHEnabled != update.SSHEnabled {
peer.SSHEnabled = update.SSHEnabled peer.SSHEnabled = update.SSHEnabled
event := activity.PeerSSHEnabled sshChanged = true
if !update.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
} }
peerLabelUpdated := peer.Name != update.Name if peer.Name != update.Name {
if peerLabelUpdated {
peer.Name = update.Name peer.Name = update.Name
peerLabelChanged = true
existingLabels := account.getPeerDNSLabels() existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
if err != nil {
return nil, err
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels) newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil { if err != nil {
@@ -216,134 +256,107 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
peer.DNSLabel = newLabel peer.DNSLabel = newLabel
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
} }
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
if !peer.AddedWithSSOLogin() { if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
} }
peer.LoginExpirationEnabled = update.LoginExpirationEnabled peer.LoginExpirationEnabled = update.LoginExpirationEnabled
loginExpirationChanged = true
}
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
inactivityExpirationChanged = true
}
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return nil, err
}
if sshChanged {
event := activity.PeerSSHEnabled
if !peer.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
if peerLabelChanged {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
am.updateAccountPeers(ctx, accountID)
}
if loginExpirationChanged {
event := activity.PeerLoginExpirationEnabled event := activity.PeerLoginExpirationEnabled
if !update.LoginExpirationEnabled { if !peer.LoginExpirationEnabled {
event = activity.PeerLoginExpirationDisabled event = activity.PeerLoginExpirationDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
} }
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { if inactivityExpirationChanged {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
event := activity.PeerInactivityExpirationEnabled event := activity.PeerInactivityExpirationEnabled
if !update.InactivityExpirationEnabled { if !peer.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled event = activity.PeerInactivityExpirationDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
} }
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
if peerLabelUpdated {
am.updateAccountPeers(ctx, account)
}
return peer, nil return peer, nil
} }
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
// we might have some inconsistencies
peers := make([]*nbpeer.Peer, 0, len(peerIDs))
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peers = append(peers, peer)
}
// the 2nd loop performs the actual modification
for _, peer := range peers {
err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID)
if err != nil {
return err
}
account.DeletePeer(peer.ID)
am.peersUpdateManager.SendUpdate(ctx, peer.ID,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill those field for backward compatibility
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
// new field
NetworkMap: &proto.NetworkMap{
Serial: account.Network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
}
return nil
}
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
updateAccountPeers := isPeerInActiveGroup(account, peerID) if peerAccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
err = am.deletePeers(ctx, account, []string{peerID}, userID) updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
if err != nil { if err != nil {
return err return err
} }
err = am.Store.SaveAccount(ctx, account) var peer *nbpeer.Peer
if err != nil { var addPeerRemovedEvents []func()
return err
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to get peer to delete: %w", err)
}
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
return nil
})
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent()
} }
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -405,7 +418,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false addedByUser := false
if len(userID) > 0 { if len(userID) > 0 {
addedByUser = true addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(userID) accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
} else { } else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
} }
@@ -436,12 +449,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
var newPeer *nbpeer.Peer var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var setupKeyID string var setupKeyID string
var setupKeyName string var setupKeyName string
var ephemeral bool var ephemeral bool
var groupsToAdd []string
if addedByUser { if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil { if err != nil {
@@ -550,7 +563,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err) return fmt.Errorf("failed to add peer to account: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, accountID) err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -584,30 +597,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
unlock() unlock()
unlock = nil unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
allGroup, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks := am.getPeerPostureChecks(account, newPeer) if updateAccountPeers {
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) am.updateAccountPeers(ctx, accountID)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) }
return newPeer, networkMap, postureChecks, nil
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
} }
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
@@ -630,14 +629,14 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
} }
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError() return nil, nil, nil, status.NewPeerNotRegisteredError()
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := account.FindUser(peer.UserID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -648,48 +647,38 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
} }
} }
if peerLoginExpired(ctx, peer, account.Settings) { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
if peerLoginExpired(ctx, peer, settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError() return nil, nil, nil, status.NewPeerLoginExpiredError()
} }
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
if err != nil {
return nil, nil, nil, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(sync.Meta) updated := peer.UpdateMetaIfNew(sync.Meta)
if updated { if updated {
err = am.Store.SavePeer(ctx, account.Id, peer) err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account)
}
} }
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if isStatusChanged || (updated && sync.UpdateAccountPeers) {
if err != nil { am.updateAccountPeers(ctx, accountID)
return nil, nil, nil, err
} }
var postureChecks []*posture.Checks return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, postureChecks, nil
}
if isStatusChanged {
am.updateAccountPeers(ctx, account)
}
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
@@ -764,7 +753,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -795,7 +784,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
if shouldStorePeer { if shouldStorePeer {
err = am.Store.SavePeer(ctx, accountID, peer) err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -804,16 +793,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
unlockPeer() unlockPeer()
unlockPeer = nil unlockPeer = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if updateRemotePeers || isStatusChanged { if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
} }
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
@@ -845,21 +829,33 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks
if isRequiresApproval { if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &NetworkMap{ emptyMap := &NetworkMap{
Network: account.Network.Copy(), Network: network.Copy(),
} }
return peer, emptyMap, nil, nil return peer, emptyMap, nil, nil
} }
approvedPeersMap, err := am.GetValidatedPeers(account) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -873,7 +869,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
// If peer was expired before and if it reached this point, it is re-authenticated. // If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer. // UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin() peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, peer.AccountID, peer) err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
if err != nil { if err != nil {
return err return err
} }
@@ -920,45 +916,51 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
} }
peer := account.GetPeer(peerID) peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) return nil, err
} }
// if admin or user owns this peer, return peer // if admin or user owns this peer, return peer
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { if user.IsAdminOrServiceUser() || peer.UserID == userID {
return peer, nil return peer, nil
} }
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := account.FindUserPeers(userID) userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
@@ -973,7 +975,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
start := time.Now() start := time.Now()
defer func() { defer func() {
if am.metrics != nil { if am.metrics != nil {
@@ -981,9 +983,15 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
} }
}() }()
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
peers := account.GetPeers() peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return return
@@ -1007,7 +1015,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done() defer wg.Done()
defer func() { <-semaphore }() defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p) postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
@@ -1017,6 +1030,236 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait() wg.Wait()
} }
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return 0, false
}
if len(peersWithExpiry) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return 0, false
}
if len(peersWithInactivity) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithInactivity {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers, nil
}
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, inactivePeer := range peersWithInactivity {
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers, nil
}
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroups := make([]*nbgroup.Group, 0)
for _, group := range groups {
if slices.Contains(group.Peers, peerID) {
peerGroups = append(peerGroups, group)
}
}
return peerGroups, nil
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
if err != nil {
return nil, err
}
groupIDs := make([]string, 0, len(groups))
for _, group := range groups {
groupIDs = append(groupIDs, group.ID)
}
return groupIDs, err
}
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
existingLabels := make(lookupMap)
for _, label := range dnsLabels {
existingLabels[label] = struct{}{}
}
return existingLabels, nil
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
return false, err
}
return am.areGroupChangesAffectPeers(ctx, accountID, peerGroupIDs)
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, store Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func()
for _, peer := range peers {
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil {
return nil, fmt.Errorf("failed to validate peer: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get account network: %w", err)
}
if err = store.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil {
return nil, fmt.Errorf("failed to delete peer: %w", err)
}
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
})
}
return peerDeletedEvents, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} { func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels)) labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels { for _, label := range existingLabels {
@@ -1024,15 +1267,3 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
} }
return labelMap return labelMap
} }
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
}

View File

@@ -44,7 +44,7 @@ type Peer struct {
// CreatedAt records the time the peer was created // CreatedAt records the time the peer was created
CreatedAt time.Time CreatedAt time.Time
// Indicate ephemeral peer attribute // Indicate ephemeral peer attribute
Ephemeral bool Ephemeral bool `gorm:"index"`
// Geo location based on connection IP // Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"` Location Location `gorm:"embedded;embeddedPrefix:location_"`
} }

View File

@@ -467,21 +467,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := "account_creator" adminUser := "account_creator"
someUser := "some_user" someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "") err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
account.Users[someUser] = &User{ require.NoError(t, err, "failed to create account")
Id: someUser,
Role: UserRoleUser,
}
account.Settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccount(context.Background(), account) err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
if err != nil { Id: someUser,
t.Fatal(err) AccountID: accountID,
return Role: UserRoleUser,
} })
require.NoError(t, err, "failed to create user")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
// two peers one added by a regular user and one with a setup key // two peers one added by a regular user and one with a setup key
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -535,7 +539,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer) assert.NotNil(t, peer)
// delete the all-to-all policy so that user's peer1 has no access to peer2 // delete the all-to-all policy so that user's peer1 has no access to peer2
for _, policy := range account.Policies { accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account policies")
for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -563,7 +570,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer) assert.NotNil(t, peer)
} }
func TestDefaultAccountManager_GetPeers(t *testing.T) { func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
role UserRole role UserRole
@@ -654,21 +661,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := "account_creator" adminUser := "account_creator"
someUser := "some_user" someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{ err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
require.NoError(t, err, "failed to create account")
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: someUser, Id: someUser,
AccountID: accountID,
Role: testCase.role, Role: testCase.role,
IsServiceUser: testCase.isServiceUser, IsServiceUser: testCase.isServiceUser,
} })
account.Policies = []*Policy{} require.NoError(t, err, "failed to create user")
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccount(context.Background(), account) accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
if err != nil { require.NoError(t, err, "failed to get account policies")
t.Fatal(err)
return for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
require.NoError(t, err, "failed to delete policy")
} }
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
peerKey1, err := wgtypes.GeneratePrivateKey() peerKey1, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -699,7 +718,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return return
} }
peers, err := manager.GetPeers(context.Background(), accountID, someUser) peers, err := manager.GetUserPeers(context.Background(), accountID, someUser)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -724,10 +743,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
adminUser := "account_creator" adminUser := "account_creator"
regularUser := "regular_user" regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "") err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
account.Users[regularUser] = &User{ if err != nil {
Id: regularUser, return nil, "", "", err
Role: UserRoleUser, }
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: regularUser,
AccountID: accountID,
Role: UserRoleUser,
})
if err != nil {
return nil, "", "", err
} }
// Create peers // Create peers
@@ -741,31 +768,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
UserID: regularUser, UserID: regularUser,
} }
account.Peers[peer.ID] = peer err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, "", "", err
}
} }
// Create groups and policies // Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ { for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i) groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{ group := &nbgroup.Group{
ID: groupID, ID: groupID,
Name: fmt.Sprintf("Group %d", i), AccountID: accountID,
Name: fmt.Sprintf("Group %d", i),
} }
for j := 0; j < peers/groups; j++ { for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
} }
account.Groups[groupID] = group
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
return nil, "", "", err
}
// Create a policy for this group // Create a policy for this group
policy := &Policy{ policy := &Policy{
ID: fmt.Sprintf("policy-%d", i), ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i), AccountID: accountID,
Enabled: true, Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: fmt.Sprintf("rule-%d", i), ID: fmt.Sprintf("rule-%d", i),
PolicyID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i), Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true, Enabled: true,
Sources: []string{groupID}, Sources: []string{groupID},
@@ -776,22 +812,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
}, },
}, },
} }
account.Policies = append(account.Policies, policy)
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, "", "", err
}
} }
account.PostureChecks = []*posture.Checks{ err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{
{ ID: "PostureChecksAll",
ID: "PostureChecksAll", AccountID: accountID,
Name: "All", Name: "All",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1", MinVersion: "0.0.1",
},
}, },
}, },
} })
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, "", "", err return nil, "", "", err
} }
@@ -824,9 +861,9 @@ func BenchmarkGetPeers(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID) _, err := manager.GetUserPeers(context.Background(), accountID, userID)
if err != nil { if err != nil {
b.Fatalf("GetPeers failed: %v", err) b.Fatalf("GetUserPeers failed: %v", err)
} }
} }
}) })
@@ -876,7 +913,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now() start := time.Now()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account) manager.updateAccountPeers(ctx, accountID)
} }
duration := time.Since(start) duration := time.Since(start)
@@ -1401,10 +1438,13 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update // Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) { t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy", ID: "policy",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: "policy",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},

View File

@@ -8,9 +8,8 @@ import (
"time" "time"
b "github.com/hashicorp/go-secure-stdlib/base62" b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
"github.com/rs/xid"
) )
const ( const (
@@ -58,7 +57,7 @@ type PersonalAccessTokenGenerated struct {
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { func CreateNewPAT(name string, expirationInDays int, targetUserID, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken() hashedToken, plainToken, err := generateNewToken()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -67,6 +66,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
return &PersonalAccessTokenGenerated{ return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{ PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(), ID: xid.New().String(),
UserID: targetUserID,
Name: name, Name: name,
HashedToken: hashedToken, HashedToken: hashedToken,
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),

View File

@@ -3,13 +3,13 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"slices" "fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -171,6 +171,7 @@ type Policy struct {
func (p *Policy) Copy() *Policy { func (p *Policy) Copy() *Policy {
c := &Policy{ c := &Policy{
ID: p.ID, ID: p.ID,
AccountID: p.AccountID,
Name: p.Name, Name: p.Name,
Description: p.Description, Description: p.Description,
Enabled: p.Enabled, Enabled: p.Enabled,
@@ -211,7 +212,6 @@ func (p *Policy) ruleGroups() []string {
groups = append(groups, rule.Sources...) groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...) groups = append(groups, rule.Destinations...)
} }
return groups return groups
} }
@@ -343,30 +343,73 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
account.Network.IncSerial() postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err
}
for index, rule := range policy.Rules {
rule.Sources = getValidGroupIDs(groups, rule.Sources)
rule.Destinations = getValidGroupIDs(groups, rule.Destinations)
policy.Rules[index] = rule
}
if policy.SourcePostureChecks != nil {
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
}
updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, isUpdate)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
saveFunc := transaction.SavePolicy
if !isUpdate {
saveFunc = transaction.CreatePolicy
}
if err := saveFunc(ctx, LockingStrengthUpdate, policy); err != nil {
return fmt.Errorf("failed to save policy: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
@@ -377,7 +420,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -385,115 +428,91 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store // DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
policy, err := am.deletePolicy(account, policyID) if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
if err != nil { if err != nil {
return err return err
} }
account.Network.IncSerial() updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, false)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if anyGroupHasPeers(account, policy.ruleGroups()) { if err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID); err != nil {
am.updateAccountPeers(ctx, account) return fmt.Errorf("failed to delete policy: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
// ListPolicies from the store // ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
policyIdx := -1 func (am *DefaultAccountManager) arePolicyChangesAffectPeers(ctx context.Context, policy *Policy, isUpdate bool) (bool, error) {
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate { if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) existingPolicy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policy.AccountID, policy.ID)
if policyIdx < 0 { if err != nil {
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) return false, err
} }
oldPolicy := account.Policies[policyIdx] if !policy.Enabled && !existingPolicy.Enabled {
// Update the existing policy
account.Policies[policyIdx] = policyToSave
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil return false, nil
} }
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil hasPeers, err := am.anyGroupHasPeers(ctx, policy.AccountID, existingPolicy.ruleGroups())
} if err != nil {
return false, err
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
} }
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, policy.AccountID, policy.ruleGroups())
} }
return result
return am.anyGroupHasPeers(ctx, policy.AccountID, policy.ruleGroups())
} }
// getAllPeersFromGroups for given peer ID and list of groups // getAllPeersFromGroups for given peer ID and list of groups
@@ -574,27 +593,52 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil return nil
} }
// filterValidPostureChecks filters and returns the posture check IDs from the given list // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
// that are valid within the provided account. func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { validPostureCheckIDs := make(map[string]struct{})
result := make([]string, 0, len(postureChecksIds)) for _, check := range postureChecks {
validPostureCheckIDs[check.ID] = struct{}{}
}
validIDs := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds { for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks { if _, exists := validPostureCheckIDs[id]; exists {
if id == postureCheck.ID { validIDs = append(validIDs, id)
result = append(result, id)
continue
}
} }
} }
return result
return validIDs
} }
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. // getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func filterValidGroupIDs(account *Account, groupIDs []string) []string { func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs)) validGroupIDs := make(map[string]struct{})
for _, groupID := range groupIDs { for _, group := range groups {
if _, exists := account.Groups[groupID]; exists { validGroupIDs[group.ID] = struct{}{}
result = append(result, groupID) }
validIDs := make([]string, 0, len(groupIDs))
for _, id := range groupIDs {
if _, exists := validGroupIDs[id]; exists {
validIDs = append(validIDs, id)
}
}
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
} }
} }
return result return result

View File

@@ -832,24 +832,28 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{ {
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
}, },
{ {
ID: "groupB", ID: "groupB",
Name: "GroupB", AccountID: account.Id,
Peers: []string{}, Name: "GroupB",
Peers: []string{},
}, },
{ {
ID: "groupC", ID: "groupC",
Name: "GroupC", AccountID: account.Id,
Peers: []string{}, Name: "GroupC",
Peers: []string{},
}, },
{ {
ID: "groupD", ID: "groupD",
Name: "GroupD", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID}, Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
}, },
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -862,11 +866,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with rule groups with no peers should not update account's peers and not send peer update // Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) { t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-rule-groups-no-peers", ID: "policy-rule-groups-no-peers",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-rule-groups-no-peers",
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupC"}, Destinations: []string{"groupC"},
@@ -896,11 +902,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// update account's peers and send peer update // update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-source-has-peers-destination-none", ID: "policy-source-has-peers-destination-none",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-source-has-peers-destination-none",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupB"}, Destinations: []string{"groupB"},
@@ -931,11 +939,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// update account's peers and send peer update // update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-destination-has-peers-source-none", ID: "policy-destination-has-peers-source-none",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-destination-has-peers-source-none",
Enabled: false, Enabled: false,
Sources: []string{"groupC"}, Sources: []string{"groupC"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
@@ -966,11 +976,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// and send peer update // and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-source-destination-peers", ID: "policy-source-destination-peers",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
@@ -1000,11 +1012,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// and send peer update // and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-source-destination-peers", ID: "policy-source-destination-peers",
Enabled: false, AccountID: account.Id,
Enabled: false,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
@@ -1035,11 +1049,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-source-destination-peers", ID: "policy-source-destination-peers",
AccountID: account.Id,
Description: "updated description", Description: "updated description",
Enabled: false, Enabled: false,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -1069,11 +1085,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// and send peer update // and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{ policy := Policy{
ID: "policy-source-destination-peers", ID: "policy-source-destination-peers",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},

View File

@@ -2,16 +2,14 @@ package server
import ( import (
"context" "context"
"fmt"
"slices" "slices"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
const (
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
) )
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -20,85 +18,127 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, err return nil, err
} }
if !user.HasAdminPower() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
} }
if !user.HasAdminPower() { if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewAdminPermissionError()
} }
if err := postureChecks.Validate(); err != nil { return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.NewAdminPermissionError()
}
if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint return status.Errorf(status.InvalidArgument, err.Error()) //nolint
} }
exists, uniqName := am.savePostureChecks(account, postureChecks) updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate)
if err != nil {
// we do not allow create new posture checks with non uniq name return err
if !exists && !uniqName {
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
} }
action := activity.PostureCheckCreated action := activity.PostureCheckCreated
if exists { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
action = activity.PostureCheckUpdated if isUpdate {
account.Network.IncSerial() action = activity.PostureCheckUpdated
}
if err = am.Store.SaveAccount(ctx, account); err != nil { if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return fmt.Errorf("failed to get posture checks: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
}
if err = transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks); err != nil {
return fmt.Errorf("failed to save posture checks: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { func (am *DefaultAccountManager) validatePostureChecks(ctx context.Context, accountID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) if err := postureChecks.Validate(); err != nil {
defer unlock() return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
account, err := am.Store.GetAccount(ctx, accountID) checks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
user, err := account.FindUser(userID) for _, check := range checks {
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
}
}
return nil
}
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() { if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return status.NewAdminPermissionError()
} }
postureChecks, err := am.deletePostureChecks(account, postureChecksID) postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil { if err != nil {
return err return err
} }
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
return fmt.Errorf("failed to delete posture checks: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
@@ -107,132 +147,123 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return nil return nil
} }
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
uniqName = true func (am *DefaultAccountManager) isPostureCheckLinkedToPolicy(ctx context.Context, postureChecksID, accountID string) error {
for i, p := range account.PostureChecks { policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if !exists && p.ID == postureChecks.ID { if err != nil {
account.PostureChecks[i] = postureChecks return err
exists = true }
}
if p.Name == postureChecks.Name { for _, policy := range policies {
uniqName = false if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
} }
} }
if !exists {
account.PostureChecks = append(account.PostureChecks, postureChecks)
}
return
}
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { return nil
postureChecksIdx := -1
for i, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
return postureChecks, nil
} }
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]posture.Checks) postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil || len(postureChecks) == 0 {
if len(account.PostureChecks) == 0 { return nil, err
return nil
} }
for _, policy := range account.Policies { policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerPostureChecks := make(map[string]*posture.Checks)
for _, policy := range policies {
if !policy.Enabled { if !policy.Enabled {
continue continue
} }
if isPeerInPolicySourceGroups(peer.ID, account, policy) { isInGroup, err := am.isPeerInPolicySourceGroups(ctx, accountID, peerID, policy)
addPolicyPostureChecks(account, policy, peerPostureChecks) if err != nil {
return nil, err
}
if isInGroup {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
if err == nil {
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
}
} }
} }
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) return maps.Values(peerPostureChecks), nil
for _, check := range peerPostureChecks {
checkCopy := check
postureChecksList = append(postureChecksList, &checkCopy)
}
return postureChecksList
} }
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { func (am *DefaultAccountManager) isPeerInPolicySourceGroups(ctx context.Context, accountID, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if !rule.Enabled { if !rule.Enabled {
continue continue
} }
for _, sourceGroup := range rule.Sources { for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup] group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
if ok && slices.Contains(group.Peers, peerID) { if err != nil {
return true log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
}
if slices.Contains(group.Peers, peerID) {
return true, nil
} }
} }
} }
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil return false, nil
} }
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { func (am *DefaultAccountManager) arePostureCheckChangesAffectPeers(ctx context.Context, accountID, postureCheckID string, exists bool) (bool, error) {
if !exists { if !exists {
return false return false, nil
} }
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if !isLinked { if err != nil {
return false return false, err
} }
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := am.anyGroupHasPeers(ctx, accountID, policy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
} }

View File

@@ -5,8 +5,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/xid" "github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/group"
@@ -26,41 +27,43 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestPostureChecksAccount(am) accountID, err := initTestPostureChecksAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
t.Run("Generic posture check flow", func(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks // regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) err := am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}, false)
assert.Error(t, err) assert.Error(t, err)
// regular users cannot list check // regular users cannot list check
_, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) _, err = am.ListPostureChecks(context.Background(), accountID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// should be possible to create posture check with uniq name // should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID, ID: postureCheckID,
Name: postureCheckName, AccountID: accountID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.26.0", MinVersion: "0.26.0",
}, },
}, },
}) }, false)
assert.NoError(t, err) assert.NoError(t, err)
// admin users can list check // admin users can list check
checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, checks, 1) assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name // should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: "new-id", ID: "new-id",
Name: postureCheckName, AccountID: accountID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{ GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{ Locations: []posture.Location{
@@ -70,57 +73,61 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
}, },
}, },
}, },
}) }, false)
assert.Error(t, err) assert.Error(t, err)
// admins can update posture checks // admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID, ID: postureCheckID,
Name: postureCheckName, AccountID: accountID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0", MinVersion: "0.27.0",
}, },
}, },
}) }, false)
assert.NoError(t, err) assert.NoError(t, err)
// users should not be able to delete posture checks // users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// admin should be able to delete posture checks // admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, checks, 0) assert.Len(t, checks, 0)
}) })
} }
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) {
accountID := "testingAccount" accountID := "testingAccount"
domain := "example.com" domain := "example.com"
admin := &User{ err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
Id: adminUserID,
Role: UserRoleAdmin,
}
user := &User{
Id: regularUserID,
Role: UserRoleUser,
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
return am.Store.GetAccount(context.Background(), account.Id) err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
{
Id: adminUserID,
AccountID: accountID,
Role: UserRoleAdmin,
},
{
Id: regularUserID,
AccountID: accountID,
Role: UserRoleUser,
},
})
if err != nil {
return "", err
}
return accountID, nil
} }
func TestPostureCheckAccountPeersUpdate(t *testing.T) { func TestPostureCheckAccountPeersUpdate(t *testing.T) {
@@ -128,19 +135,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
{ {
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, },
{ {
ID: "groupB", ID: "groupB",
Name: "GroupB", AccountID: account.Id,
Peers: []string{}, Name: "GroupB",
Peers: []string{},
}, },
{ {
ID: "groupC", ID: "groupC",
Name: "GroupC", AccountID: account.Id,
Peers: []string{}, Name: "GroupC",
Peers: []string{},
}, },
}) })
assert.NoError(t, err) assert.NoError(t, err)
@@ -169,7 +179,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -192,7 +202,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -203,11 +213,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}) })
policy := Policy{ policy := Policy{
ID: "policyA", ID: "policyA",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: "ruleA",
PolicyID: "policyA",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -255,7 +267,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -303,17 +315,19 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
}) })
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
assert.NoError(t, err) assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
policy = Policy{ policy = Policy{
ID: "policyB", ID: "policyB",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: "ruleB",
PolicyID: "policyB",
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupC"}, Destinations: []string{"groupC"},
@@ -337,7 +351,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -355,11 +369,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
}) })
policy = Policy{ policy = Policy{
ID: "policyB", ID: "policyB",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), ID: "ruleB",
PolicyID: "policyB",
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -384,7 +400,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -398,10 +414,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// should trigger account peers update and send peer update // should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{ policy = Policy{
ID: "policyB", ID: "policyB",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "ruleB",
PolicyID: "policyB",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupB"}, Destinations: []string{"groupB"},
@@ -429,7 +448,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -441,79 +460,126 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
func TestArePostureCheckChangesAffectingPeers(t *testing.T) { func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
account := &Account{ manager, err := createManager(t)
Policies: []*Policy{ require.NoError(t, err, "failed to create account manager")
{
ID: "policyA", accountID, err := initTestPostureChecksAccount(manager)
Rules: []*PolicyRule{ require.NoError(t, err, "failed to init testing account")
{
Enabled: true, groupA := &group.Group{
Sources: []string{"groupA"}, ID: "groupA",
Destinations: []string{"groupA"}, AccountID: accountID,
}, Peers: []string{"peer1"},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
} }
groupB := &group.Group{
ID: "groupB",
AccountID: accountID,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
policy := &Policy{
ID: "policyA",
AccountID: accountID,
Rules: []*PolicyRule{
{
ID: "ruleA",
PolicyID: "policyA",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to save policy")
postureCheckA := &posture.Checks{
ID: "checkA",
Name: "checkA",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA, false)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
ID: "checkB",
Name: "checkB",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB, false)
require.NoError(t, err, "failed to save postureCheckB")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true) result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkB", true)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check does not exist", func(t *testing.T) { t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false) result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "unknown", false)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"} policy.Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"} policy.Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} groupA.Peers = []string{}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) require.NoError(t, err, "failed to save groups")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{} policy.Rules[0].Sources = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) policy.Rules[0].Destinations = []string{"nonExistentGroup"}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, sErr.Type())
assert.False(t, result) assert.False(t, result)
}) })
} }

View File

@@ -52,17 +52,43 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(ctx context.Context, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, accountID, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(ctx, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@@ -81,8 +107,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if group == nil { if err != nil || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@@ -97,10 +123,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
if peerID != "" { if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -109,7 +136,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. // we validated the group existence before entering this function, no need to check again.
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil || group == nil {
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
}
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
@@ -120,10 +151,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -143,16 +175,22 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
// Do not allow non-Linux peers // Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil { if peerID != "" {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
if peer.Meta.GoOS != "linux" { if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
} }
@@ -179,22 +217,28 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
var newRoute route.Route var newRoute route.Route
newRoute.ID = route.ID(xid.New().String()) newRoute.ID = route.ID(xid.New().String())
newRoute.AccountID = accountID
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if len(peerGroupIDs) > 0 { if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups) err = validateGroups(peerGroupIDs, accountGroups)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if len(accessControlGroupIDs) > 0 { if len(accessControlGroupIDs) > 0 {
err = validateGroups(accessControlGroupIDs, account.Groups) err = validateGroups(accessControlGroupIDs, accountGroups)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -207,7 +251,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
err = validateGroups(groups, account.Groups) err = validateGroups(groups, accountGroups)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -226,30 +270,46 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.KeepRoute = keepRoute newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil { updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, &newRoute)
account.Routes = make(map[route.ID]*route.Route) if err != nil {
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err return nil, err
} }
if isRouteChangeAffectPeers(account, &newRoute) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, &newRoute)
if err != nil {
return fmt.Errorf("failed to create route: %w", err)
}
return nil
})
if err != nil {
return nil, err
} }
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return &newRoute, nil return &newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock() if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
@@ -263,18 +323,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
account, err := am.Store.GetAccount(ctx, accountID) oldRoute, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeToSave.ID))
if err != nil { if err != nil {
return err return err
} }
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@@ -291,72 +344,119 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
// Do not allow non-Linux peers
if routeToSave.Peer != "" {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer)
if err != nil {
return err
}
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if len(routeToSave.PeerGroups) > 0 { if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups) err = validateGroups(routeToSave.PeerGroups, groups)
if err != nil { if err != nil {
return err return err
} }
} }
if len(routeToSave.AccessControlGroups) > 0 { if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups) err = validateGroups(routeToSave.AccessControlGroups, groups)
if err != nil { if err != nil {
return err return err
} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil { if err != nil {
return err return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) err = validateGroups(routeToSave.Groups, groups)
if err != nil { if err != nil {
return err return err
} }
oldRoute := account.Routes[routeToSave.ID] oldRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, oldRoute)
account.Routes[routeToSave.ID] = routeToSave if err != nil {
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { newRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, routeToSave)
am.updateAccountPeers(ctx, account) if err != nil {
return err
}
routeToSave.AccountID = accountID
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave)
if err != nil {
return fmt.Errorf("failed to save route: %w", err)
}
return nil
})
if err != nil {
return err
} }
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
routy := account.Routes[routeID] if user.AccountID != accountID {
if routy == nil { return status.NewUserNotPartOfAccountError()
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
} }
delete(account.Routes, routeID)
account.Network.IncSerial() route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, route)
if err != nil {
return err
}
if isRouteChangeAffectPeers(account, routy) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID)); err != nil {
return fmt.Errorf("failed to delete route: %w", err)
}
return nil
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -369,8 +469,12 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
@@ -649,8 +753,21 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers.
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { func (am *DefaultAccountManager) areRouteChangesAffectPeers(ctx context.Context, route *route.Route) (bool, error) {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" if route.Peer != "" {
return true, nil
}
hasPeers, err := am.anyGroupHasPeers(ctx, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, route.AccountID, route.PeerGroups)
} }

View File

@@ -5,19 +5,20 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@@ -427,21 +428,22 @@ func TestCreateRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Errorf("failed to init testing account: %s", err) t.Errorf("failed to init testing account: %s", err)
} }
if testCase.createInitRoute { if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll() groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, errInit) require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit) require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) _, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit) require.NoError(t, errInit)
} }
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err) testCase.errFunc(t, err)
@@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) {
// assign generated ID // assign generated ID
testCase.expectedRoute.ID = outRoute.ID testCase.expectedRoute.ID = outRoute.ID
testCase.expectedRoute.AccountID = accountID
if !testCase.expectedRoute.IsEqual(outRoute) { if !testCase.expectedRoute.IsEqual(outRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute) t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute)
@@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
if testCase.createInitRoute { if testCase.createInitRoute {
account.Routes["initRoute"] = &route.Route{ initRoute := &route.Route{
ID: "initRoute", ID: "initRoute",
AccountID: accountID,
Network: existingNetwork, Network: existingNetwork,
NetID: existingRouteID, NetID: existingRouteID,
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
@@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) {
Enabled: true, Enabled: true,
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
} }
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute)
require.NoError(t, err, "failed to save init route")
} }
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute testCase.existingRoute.AccountID = accountID
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save existing route")
if err != nil {
t.Error("account should be saved")
}
var routeToSave *route.Route var routeToSave *route.Route
@@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) {
} }
} }
err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) err = am.SaveRoute(context.Background(), accountID, userID, routeToSave)
testCase.errFunc(t, err) testCase.errFunc(t, err)
@@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) {
return return
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID)
if err != nil { require.NoError(t, err, "failed to get saved route")
t.Fatal(err)
}
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved)
testCase.expectedRoute.AccountID = accountID
if !testCase.expectedRoute.IsEqual(savedRoute) { if !testCase.expectedRoute.IsEqual(savedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
} }
@@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) {
} }
func TestDeleteRoute(t *testing.T) { func TestDeleteRoute(t *testing.T) {
testingRoute := &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
Domains: domain.List{"domain1", "domain2"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1Key,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
}
am, err := createRouterManager(t) am, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.Routes[testingRoute.ID] = testingRoute err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "GroupA",
AccountID: accountID,
Name: "GroupA",
})
require.NoError(t, err, "failed to save group")
err = am.Store.SaveAccount(context.Background(), account) testingRoute := &route.Route{
if err != nil { Network: netip.MustParsePrefix("192.168.0.0/16"),
t.Error("failed to save account") NetID: route.NetID("12345678901234567890qw"),
Groups: []string{"GroupA"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
} }
createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute)
require.NoError(t, err, "failed to create route")
err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID)
if err != nil { if err != nil {
t.Error("deleting route failed with error: ", err) t.Error("deleting route failed with error: ", err)
} }
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) _, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID)
if err != nil { require.NotNil(t, err)
t.Error("failed to retrieve saved account with error: ", err) sErr, ok := status.FromError(err)
} require.True(t, ok)
require.Equal(t, status.NotFound, sErr.Type())
_, found := savedAccount.Routes[testingRoute.ID]
if found {
t.Error("route shouldn't be found after delete")
}
} }
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
@@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { require.NoError(t, err, "failed to init testing account")
t.Error("failed to init testing account")
}
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true) require.Equal(t, newRoute.Enabled, true)
@@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {
@@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
} }
} }
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID)
require.NoError(t, err) require.NoError(t, err)
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err) require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err) require.NoError(t, err)
peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1158,16 +1153,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { require.NoError(t, err, "failed to init testing account")
t.Error("failed to init testing account")
}
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err) require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1181,7 +1174,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
expectedRoute := enabledRoute.Copy() expectedRoute := enabledRoute.Copy()
expectedRoute.Peer = peer1Key expectedRoute.Peer = peer1Key
err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute)
require.NoError(t, err) require.NoError(t, err)
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1193,7 +1186,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group")
err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID)
require.NoError(t, err) require.NoError(t, err)
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
@@ -1202,27 +1195,29 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group")
newGroup := &nbgroup.Group{ newGroup := &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: "peer1 group", AccountID: accountID,
Peers: []string{peer1ID}, Name: "peer1 group",
Peers: []string{peer1ID},
} }
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) err = am.SaveGroup(context.Background(), accountID, userID, newGroup)
require.NoError(t, err) require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") policies, err := am.ListPolicies(context.Background(), accountID, "testingUser")
require.NoError(t, err) require.NoError(t, err)
defaultRule := rules[0] defaultRule := policies[0]
newPolicy := defaultRule.Copy() newPolicy := defaultRule.Copy()
newPolicy.ID = xid.New().String() newPolicy.ID = xid.New().String()
newPolicy.Name = "peer1 only" newPolicy.Name = "peer1 only"
newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) err = am.SavePolicy(context.Background(), accountID, userID, newPolicy, false)
require.NoError(t, err) require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1233,7 +1228,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2")
err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1267,179 +1262,104 @@ func createRouterStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain) err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
ips := account.getTakenIPs() createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) {
peer1IP, err := AllocatePeerIP(account.Network.Net, ips) ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerIP, err := AllocatePeerIP(network.Net, ips)
if err != nil {
return nil, err
}
peer := &nbpeer.Peer{
IP: peerIP,
AccountID: accountID,
ID: peerID,
Key: peerKey,
Name: peerName,
DNSLabel: dnsLabel,
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: peerName,
GoOS: strings.ToLower(kernel),
Kernel: kernel,
Core: core,
Platform: platform,
OS: os,
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil {
return nil, err
}
return peer, nil
}
// Create peers
peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil { if err != nil {
return nil, err return "", err
} }
peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu")
peer1 := &nbpeer.Peer{
IP: peer1IP,
ID: peer1ID,
Key: peer1Key,
Name: "test-host1@netbird.io",
DNSLabel: "test-host1",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer1.ID] = peer1
ips = account.getTakenIPs()
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin")
peer2 := &nbpeer.Peer{
IP: peer2IP,
ID: peer2ID,
Key: peer2Key,
Name: "test-host2@netbird.io",
DNSLabel: "test-host2",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer2.ID] = peer2
ips = account.getTakenIPs()
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu")
peer3 := &nbpeer.Peer{
IP: peer3IP,
ID: peer3ID,
Key: peer3Key,
Name: "test-host3@netbird.io",
DNSLabel: "test-host3",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
Kernel: "Darwin",
Core: "13.4.1",
Platform: "arm64",
OS: "darwin",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer3.ID] = peer3
ips = account.getTakenIPs()
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu")
peer4 := &nbpeer.Peer{
IP: peer4IP,
ID: peer4ID,
Key: peer4Key,
Name: "test-host4@netbird.io",
DNSLabel: "test-host4",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer4.ID] = peer4
ips = account.getTakenIPs()
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer5 := &nbpeer.Peer{ groupAll, err := am.GetGroupByName(context.Background(), "All", accountID)
IP: peer5IP, if err != nil {
ID: peer5ID, return "", err
Key: peer5Key,
Name: "test-host5@netbird.io",
DNSLabel: "test-host5",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host5@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer5.ID] = peer5
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
groupAll, err := account.GetGroupAll()
if err != nil {
return nil, err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
newGroup := []*nbgroup.Group{ newGroups := []*nbgroup.Group{
{ {
ID: routeGroup1, ID: routeGroup1,
Name: routeGroup1, Name: routeGroup1,
@@ -1471,15 +1391,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
Peers: []string{peer1.ID, peer4.ID}, Peers: []string{peer1.ID, peer4.ID},
}, },
} }
err = am.SaveGroups(context.Background(), accountID, userID, newGroups)
for _, group := range newGroup { if err != nil {
err = am.SaveGroup(context.Background(), accountID, userID, group) return "", err
if err != nil {
return nil, err
}
} }
return am.Store.GetAccount(context.Background(), account.Id) return accountID, nil
} }
func TestAccount_getPeersRoutesFirewall(t *testing.T) { func TestAccount_getPeersRoutesFirewall(t *testing.T) {
@@ -1783,10 +1700,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t) manager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager") require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager) accountID, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account") require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{ {
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
@@ -1832,7 +1749,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.Groups, []string{}, true, userID, route.KeepRoute,
) )
@@ -1868,7 +1785,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.Groups, []string{}, true, userID, route.KeepRoute,
) )
@@ -1904,7 +1821,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
newRoute, err := manager.CreateRoute( newRoute, err := manager.CreateRoute(
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
) )
@@ -1928,7 +1845,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute)
require.NoError(t, err) require.NoError(t, err)
select { select {
@@ -1946,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
select { select {
@@ -1970,7 +1887,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
} }
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
) )
@@ -1982,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupB", ID: "groupB",
Name: "GroupB", Name: "GroupB",
Peers: []string{peer1ID}, Peers: []string{peer1ID},
@@ -2010,7 +1927,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{"groupC"}, Groups: []string{"groupC"},
} }
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
) )
@@ -2022,7 +1939,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupC", ID: "groupC",
Name: "GroupC", Name: "GroupC",
Peers: []string{peer1ID}, Peers: []string{peer1ID},

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"fmt"
"hash/fnv" "hash/fnv"
"strconv" "strconv"
"strings" "strings"
@@ -12,6 +11,7 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/google/uuid" "github.com/google/uuid"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@@ -226,34 +226,44 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty. // and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if err = validateSetupKeyAutoGroups(groups, autoGroups); err != nil {
return nil, err return nil, err
} }
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
err = am.Store.SaveAccount(ctx, account)
if err != nil { if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey); err != nil {
return nil, status.Errorf(status.Internal, "failed adding account key") return nil, err
} }
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, g := range setupKey.AutoGroups { for _, g := range setupKey.AutoGroups {
group := account.GetGroup(g) group := groupMap[g]
if group != nil { if group != nil {
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
} else { } else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
} }
} }
@@ -268,30 +278,30 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil { if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
} }
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var oldKey *SetupKey if user.AccountID != accountID {
for _, key := range account.SetupKeys { return nil, status.NewUserNotPartOfAccountError()
if key.Id == keyToSave.Id {
oldKey = key.Copy()
break
}
}
if oldKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
} }
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if err = validateSetupKeyAutoGroups(groups, keyToSave.AutoGroups); err != nil {
return nil, err
}
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
if err != nil {
return nil, err return nil, err
} }
@@ -302,9 +312,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
newKey.Revoked = keyToSave.Revoked newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC() newKey.UpdatedAt = time.Now().UTC()
account.SetupKeys[newKey.Key] = newKey if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, newKey); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err return nil, err
} }
@@ -315,24 +323,30 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
defer func() { defer func() {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, g := range removedGroups { for _, g := range removedGroups {
group := account.GetGroup(g) group := groupMap[g]
if group != nil { if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else { } else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
} }
} }
for _, g := range addedGroups { for _, g := range addedGroups {
group := account.GetGroup(g) group := groupMap[g]
if group != nil { if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else { } else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
} }
} }
}() }()
@@ -347,8 +361,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError() return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
@@ -366,8 +384,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError() return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
@@ -387,21 +409,25 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user: %w", err) return err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError() return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
} }
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get setup key: %w", err) return err
} }
err = am.Store.DeleteSetupKey(ctx, accountID, keyID) err = am.Store.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err) return err
} }
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
@@ -409,15 +435,22 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return nil return nil
} }
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error {
for _, group := range autoGroups { groupMap := make(map[string]*nbgroup.Group, len(groups))
g, ok := account.Groups[group] for _, g := range groups {
if !ok { groupMap[g.ID] = g
return status.Errorf(status.NotFound, "group %s doesn't exist", group) }
for _, groupID := range autoGroups {
g, exists := groupMap[groupID]
if !exists {
return status.Errorf(status.NotFound, "group %s doesn't exist", groupID)
} }
if g.Name == "All" { if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
} }
} }
return nil return nil
} }

View File

@@ -25,21 +25,21 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { require.NoError(t, err, "failed to get or create account ID")
t.Fatal(err)
}
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{ {
ID: "group_1", ID: "group_1",
Name: "group_name_1", AccountID: accountID,
Peers: []string{}, Name: "group_name_1",
Peers: []string{},
}, },
{ {
ID: "group_2", ID: "group_2",
Name: "group_name_2", AccountID: accountID,
Peers: []string{}, Name: "group_name_2",
Peers: []string{},
}, },
}) })
if err != nil { if err != nil {
@@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
expiresIn := time.Hour expiresIn := time.Hour
keyName := "my-test-key" keyName := "my-test-key"
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{},
SetupKeyUnlimitedUsage, userID, false) SetupKeyUnlimitedUsage, userID, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups := []string{"group_1", "group_2"} autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key" newKeyName := "my-new-test-key"
revoked := true revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName, Name: newKeyName,
Revoked: revoked, Revoked: revoked,
@@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
key.Id, time.Now().UTC(), autoGroups, true) key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked)
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, newKeyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID) assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll() groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
assert.NoError(t, err) require.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error // saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID) autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ _, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName, Name: newKeyName,
Revoked: revoked, Revoked: revoked,
@@ -103,31 +103,31 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { require.NoError(t, err, "failed to get or create account ID")
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", AccountID: accountID,
Peers: []string{}, Name: "group_name_1",
Peers: []string{},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", AccountID: accountID,
Peers: []string{}, Name: "group_name_2",
Peers: []string{},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
groupAll, err := account.GetGroupAll() groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
assert.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
name string name string
@@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
for _, tCase := range []testCase{testCase1, testCase2, testCase3} { for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
if tCase.expectedFailure { if tCase.expectedFailure {
@@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
tCase.expectedUpdatedAt, tCase.expectedGroups, false) tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) ev := getEvent(t, accountID, manager, activity.SetupKeyCreated)
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"]) assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"])
assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"])) assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
@@ -208,12 +208,10 @@ func TestGetSetupKeys(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { require.NoError(t, err, "failed to get or create account ID")
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@@ -222,7 +220,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},
@@ -384,20 +382,24 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}) })
assert.NoError(t, err) assert.NoError(t, err)
policy := Policy{ policy := Policy{
ID: "policy", ID: "policy",
Enabled: true, AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "Rule",
PolicyID: "policy",
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"group"}, Destinations: []string{"groupA"},
Bidirectional: true, Bidirectional: true,
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },

File diff suppressed because it is too large Load Diff

View File

@@ -68,20 +68,27 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
func runLargeTest(t *testing.T, store Store) { func runLargeTest(t *testing.T, store Store) {
t.Helper() t.Helper()
account := newAccountWithId(context.Background(), "account_id", "testuser", "") accountID := "account_id"
groupALL, err := account.GetGroupAll()
if err != nil { err := newAccountWithId(context.Background(), store, accountID, "testuser", "")
t.Fatal(err) assert.NoError(t, err, "failed to create account")
}
groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
assert.NoError(t, err, "failed to get group All")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
assert.NoError(t, err, "failed to save setup key")
const numPerAccount = 6000 const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ { for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4() netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) peerID := fmt.Sprintf("%s-peer-%d", accountID, n)
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
ID: peerID, ID: peerID,
AccountID: accountID,
Key: peerID, Key: peerID,
IP: netIP, IP: netIP,
Name: peerID, Name: peerID,
@@ -90,16 +97,21 @@ func runLargeTest(t *testing.T, store Store) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false, SSHEnabled: false,
} }
account.Peers[peerID] = peer err = store.AddPeerToAccount(context.Background(), peer)
group, _ := account.GetGroupAll() assert.NoError(t, err, "failed to add peer")
group.Peers = append(group.Peers, peerID)
user := &User{ err = store.AddPeerToAllGroup(context.Background(), accountID, peerID)
Id: fmt.Sprintf("%s-user-%d", account.Id, n), assert.NoError(t, err, "failed to add peer to all group")
AccountID: account.Id,
} err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
account.Users[user.Id] = user Id: fmt.Sprintf("%s-user-%d", accountID, n),
AccountID: accountID,
})
assert.NoError(t, err, "failed to save user")
route := &route2.Route{ route := &route2.Route{
ID: route2.ID(fmt.Sprintf("network-id-%d", n)), ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
AccountID: accountID,
Description: "base route", Description: "base route",
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"), Network: netip.MustParsePrefix(netIP.String() + "/24"),
@@ -107,22 +119,24 @@ func runLargeTest(t *testing.T, store Store) {
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
Enabled: true, Enabled: true,
Groups: []string{groupALL.ID}, Groups: []string{groupAll.ID},
} }
account.Routes[route.ID] = route err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
assert.NoError(t, err, "failed to save route")
group = &nbgroup.Group{ group := &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n), ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id, AccountID: accountID,
Name: fmt.Sprintf("group-id-%d", n), Name: fmt.Sprintf("group-id-%d", n),
Issued: "api", Issued: "api",
Peers: nil, Peers: nil,
} }
account.Groups[group.ID] = group err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
assert.NoError(t, err, "failed to save group")
nameserver := &nbdns.NameServerGroup{ nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n), ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id, AccountID: accountID,
Name: fmt.Sprintf("nameserver-id-%d", n), Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "", Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
@@ -132,20 +146,20 @@ func runLargeTest(t *testing.T, store Store) {
Enabled: false, Enabled: false,
SearchDomainsEnabled: false, SearchDomainsEnabled: false,
} }
account.NameServerGroups[nameserver.ID] = nameserver err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver)
assert.NoError(t, err, "failed to save nameserver group")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
assert.NoError(t, err, "failed to save setup key")
} }
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 { if len(store.GetAllAccounts(context.Background())) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
a, err := store.GetAccount(context.Background(), account.Id) a, err := store.GetAccount(context.Background(), accountID)
if a == nil { if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
} }
@@ -213,41 +227,53 @@ func TestSqlite_SaveAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) require.NoError(t, err)
accountID := "account_id"
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
require.NoError(t, err, "failed to create account")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey", require.NoError(t, err, "failed to save setup key")
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account) err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
require.NoError(t, err) ID: "testpeer",
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
AccountID: accountID,
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
accountID2 := "account_id2"
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
require.NoError(t, err, "failed to create account")
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID2
account2.Peers["testpeer2"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey2", require.NoError(t, err, "failed to save setup key")
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account2) err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
require.NoError(t, err) ID: "testpeer2",
Key: "peerkey2",
AccountID: accountID2,
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
if len(store.GetAllAccounts(context.Background())) != 2 { if len(store.GetAllAccounts(context.Background())) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
} }
a, err := store.GetAccount(context.Background(), account.Id) a, err := store.GetAccount(context.Background(), accountID)
if a == nil { if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
} }
@@ -288,36 +314,52 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
accountID := "account_id"
testUserID := "testuser" testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId(context.Background(), "account_id", testUserID, "") err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 { setupKey, _ := GenerateDefaultSetupKey()
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testpeer",
Key: "peerkey",
AccountID: accountID,
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: "testtoken",
UserID: testUserID,
Name: "test token",
})
require.NoError(t, err, "failed to save personal access token")
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
err = store.DeleteAccount(context.Background(), account) err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 { accountIDs, err = store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get all account ids after DeleteAccount()")
if len(accountIDs) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
} }
@@ -400,7 +442,7 @@ func TestSqlite_SavePeer(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
ctx := context.Background() ctx := context.Background()
err = store.SavePeer(ctx, account.Id, peer) err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer)
assert.Error(t, err) assert.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@@ -416,7 +458,7 @@ func TestSqlite_SavePeer(t *testing.T) {
updatedPeer.Status.Connected = false updatedPeer.Status.Connected = false
updatedPeer.Meta.Hostname = "updatedpeer" updatedPeer.Meta.Hostname = "updatedpeer"
err = store.SavePeer(ctx, account.Id, updatedPeer) err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer)
require.NoError(t, err) require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id) account, err = store.GetAccount(context.Background(), account.Id)
@@ -442,7 +484,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
// save status of non-existing peer // save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus)
assert.Error(t, err) assert.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@@ -461,7 +503,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
err = store.SaveAccount(context.Background(), account) err = store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
err = store.SavePeerStatus(account.Id, "testpeer", newStatus) err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
require.NoError(t, err) require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id) account, err = store.GetAccount(context.Background(), account.Id)
@@ -472,7 +514,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
newStatus.Connected = true newStatus.Connected = true
err = store.SavePeerStatus(account.Id, "testpeer", newStatus) err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
require.NoError(t, err) require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id) account, err = store.GetAccount(context.Background(), account.Id)
@@ -507,7 +549,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
} }
// error is expected as peer is not in store yet // error is expected as peer is not in store yet
err = store.SavePeerLocation(account.Id, peer) err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer)
assert.Error(t, err) assert.Error(t, err)
account.Peers[peer.ID] = peer account.Peers[peer.ID] = peer
@@ -519,7 +561,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
peer.Location.CityName = "Berlin" peer.Location.CityName = "Berlin"
peer.Location.GeoNameID = 2950159 peer.Location.GeoNameID = 2950159
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID])
assert.NoError(t, err) assert.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id) account, err = store.GetAccount(context.Background(), account.Id)
@@ -529,7 +571,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
assert.Equal(t, peer.Location, actual) assert.Equal(t, peer.Location, actual)
peer.ID = "non-existing-peer" peer.ID = "non-existing-peer"
err = store.SavePeerLocation(account.Id, peer) err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer)
assert.Error(t, err) assert.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@@ -572,11 +614,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN" hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, token) require.Equal(t, id, pat.ID)
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") _, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
require.Error(t, err) require.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@@ -595,11 +637,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id) user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID) require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id") _, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
require.Error(t, err) require.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@@ -714,19 +756,28 @@ func newSqliteStore(t *testing.T) *SqlStore {
} }
func newAccount(store Store, id int) error { func newAccount(store Store, id int) error {
str := fmt.Sprintf("%s-%d", uuid.New().String(), id) accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id)
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") userID := accountID + "-testuser"
err := newAccountWithId(context.Background(), store, accountID, userID, "example.com")
if err != nil {
return err
}
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["p"+str] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey" + str, if err != nil {
return err
}
return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: accountID + "-peerkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
return store.SaveAccount(context.Background(), account)
} }
func TestPostgresql_NewStore(t *testing.T) { func TestPostgresql_NewStore(t *testing.T) {
@@ -754,39 +805,56 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") accountID := "account_id"
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
require.NoError(t, err, "failed to create account")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey", require.NoError(t, err, "failed to save setup key")
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account) err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
require.NoError(t, err) ID: "testpeer",
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
AccountID: accountID,
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
accountID2 := "account_id2"
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
require.NoError(t, err, "failed to create account")
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID2
account2.Peers["testpeer2"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey2", require.NoError(t, err, "failed to save setup key")
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account2) err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
require.NoError(t, err) ID: "testpeer2",
Key: "peerkey2",
AccountID: accountID2,
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
if len(store.GetAllAccounts(context.Background())) != 2 { accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
} }
a, err := store.GetAccount(context.Background(), account.Id) a, err := store.GetAccount(context.Background(), accountID)
if a == nil { if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
} }
@@ -827,32 +895,51 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
accountID := "account_id"
testUserID := "testuser" testUserID := "testuser"
user := NewAdminUser(testUserID) user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": { user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken", ID: "testtoken",
Name: "test token", Name: "test token",
}} }}
account := newAccountWithId(context.Background(), "account_id", testUserID, "") err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
require.NoError(t, err, "failed to create account")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey", require.NoError(t, err, "failed to save setup key")
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
err = store.SaveAccount(context.Background(), account) err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
require.NoError(t, err) ID: "testingpeer",
AccountID: accountID,
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
if len(store.GetAllAccounts(context.Background())) != 1 { err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: "testtoken",
UserID: testUserID,
Name: "test token",
})
require.NoError(t, err, "failed to save personal access token")
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
err = store.DeleteAccount(context.Background(), account) err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
@@ -908,7 +995,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
// save status of non-existing peer // save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus)
assert.Error(t, err) assert.Error(t, err)
// save new status of existing peer // save new status of existing peer
@@ -924,7 +1011,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
err = store.SaveAccount(context.Background(), account) err = store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
err = store.SavePeerStatus(account.Id, "testpeer", newStatus) err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
require.NoError(t, err) require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id) account, err = store.GetAccount(context.Background(), account.Id)
@@ -967,9 +1054,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN" hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, token) require.Equal(t, id, pat.ID)
} }
func TestPostgresql_GetUserByTokenID(t *testing.T) { func TestPostgresql_GetUserByTokenID(t *testing.T) {
@@ -984,7 +1071,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id) user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID) require.Equal(t, id, user.PATs[id].ID)
} }
@@ -1047,7 +1134,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
_, err = store.GetAccount(context.Background(), existingAccountID) _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err) require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) labels, err := store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{}, labels) assert.Equal(t, []string{}, labels)
@@ -1059,7 +1146,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), peer1) err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err) require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels) assert.Equal(t, []string{"peer1.domain.test"}, labels)
@@ -1071,7 +1158,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), peer2) err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err) require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
} }
@@ -1181,7 +1268,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group") t.Fatal("failed to save group")
return err return err
} }
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil { if err != nil {
t.Fatal("failed to get group") t.Fatal("failed to get group")
return err return err
@@ -1201,7 +1288,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID) users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, users, len(account.Users)) require.Len(t, users, len(account.Users))
} }
@@ -1218,7 +1305,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "example.com" domain := "example.com"
category := "public" category := "public"
IsDomainPrimaryAccount := false IsDomainPrimaryAccount := false
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
require.NoError(t, err) require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
@@ -1232,7 +1319,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "test.com" domain := "test.com"
category := "private" category := "private"
IsDomainPrimaryAccount := true IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
require.NoError(t, err) require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
@@ -1246,7 +1333,9 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "test.com" domain := "test.com"
category := "private" category := "private"
IsDomainPrimaryAccount := true IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id",
domain, category, &IsDomainPrimaryAccount,
)
require.Error(t, err) require.Error(t, err)
}) })
@@ -1274,7 +1363,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
require.NoError(t, err) require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
@@ -1290,6 +1379,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id" nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err) require.Error(t, err)
} }

View File

@@ -83,8 +83,8 @@ func NewPeerNotFoundError(peerKey string) error {
} }
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account // NewAccountNotFoundError creates a new Error with NotFound type for a missing account
func NewAccountNotFoundError(accountKey string) error { func NewAccountNotFoundError() error {
return Errorf(NotFound, "account not found: %s", accountKey) return Errorf(NotFound, "account not found")
} }
// NewUserNotFoundError creates a new Error with NotFound type for a missing user // NewUserNotFoundError creates a new Error with NotFound type for a missing user
@@ -102,23 +102,38 @@ func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more") return Errorf(PermissionDenied, "peer login has expired, please log in once more")
} }
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found: %s", err)
}
func NewGetAccountFromStoreError(err error) error { func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err) return Errorf(Internal, "issue getting account from store: %s", err)
} }
func NewUnauthorizedToViewAccountSettingError() error {
return Errorf(PermissionDenied, "only users with admin power can view account settings")
}
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
func NewUserNotPartOfAccountError() error {
return Errorf(PermissionDenied, "user is not part of this account")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error { func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store") return Errorf(Internal, "issue getting user from store")
} }
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context // NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
func NewStoreContextCanceledError(duration time.Duration) error { func NewAdminPermissionError() error {
return Errorf(Internal, "store access: context canceled after %v", duration) return Errorf(PermissionDenied, "admin role required to perform this action")
}
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
// to delete a user with the owner role.
func NewOwnerDeletePermissionError() error {
return Errorf(PermissionDenied, "can't delete a user with the owner role")
}
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
func NewServiceUserRoleInvalidError() error {
return Errorf(InvalidArgument, "can't create a service user with owner role")
} }
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
@@ -126,7 +141,24 @@ func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID") return Errorf(InvalidArgument, "invalid key ID")
} }
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewUnauthorizedToViewSetupKeysError() error { func NewSetupKeyNotFoundError(err error) error {
return Errorf(Unauthorized, "only users with admin power can view setup keys") return Errorf(NotFound, "setup key not found: %s", err)
}
func NewPATNotFoundError() error {
return Errorf(NotFound, "PAT not found")
}
func NewGetPATFromStoreError() error {
return Errorf(Internal, "issue getting pat from store")
}
func NewUnauthorizedToViewNSGroupsError() error {
return Errorf(PermissionDenied, "only users with admin power can view name server groups")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
} }

View File

@@ -47,65 +47,95 @@ type Store interface {
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
SavePeerLocation(accountID string, peer *nbpeer.Peer) error GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string
@@ -124,7 +154,6 @@ type Store interface {
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
} }
type StoreEngine string type StoreEngine string

View File

@@ -32,4 +32,7 @@ INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-3465300
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

File diff suppressed because it is too large Load Diff

View File

@@ -43,37 +43,34 @@ const (
func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
} }
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
if err != nil { if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
assert.Equal(t, pat.CreatedBy, mockUserID) assert.Equal(t, newPAT.CreatedBy, mockUserID)
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
if err != nil { if err != nil {
t.Fatalf("Error when getting token ID by hashed token: %s", err) t.Fatalf("Error when getting token ID by hashed token: %s", err)
} }
if tokenID == "" { if pat.ID == "" {
t.Fatal("GetTokenIDByHashedToken failed after adding PAT") t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
} }
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, newPAT.ID, pat.ID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }
@@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId, Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: false, IsServiceUser: false,
} })
err := store.SaveAccount(context.Background(), account) assert.NoError(t, err, "failed to create user")
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId, Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: true, IsServiceUser: true,
} })
err := store.SaveAccount(context.Background(), account) assert.NoError(t, err, "failed to create user")
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
func TestUser_CreatePAT_WithEmptyName(t *testing.T) { func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
} }
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
assert.Errorf(t, err, "Wrong expiration should thorw error") assert.Errorf(t, err, "Wrong expiration should throw error")
} }
func TestUser_DeletePAT(t *testing.T) { func TestUser_DeletePAT(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
PATs: map[string]*PersonalAccessToken{
mockTokenID1: { err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID1, ID: mockTokenID1,
HashedToken: mockToken1, UserID: mockUserID,
}, HashedToken: mockToken1,
}, })
} assert.NoError(t, err, "failed to create PAT")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
if err != nil { if err != nil {
t.Fatalf("Error when getting account: %s", err) t.Fatalf("Error when getting account: %s", err)
} }
@@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) {
func TestUser_GetPAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
mockTokenID1: { ID: mockTokenID1,
ID: mockTokenID1, UserID: mockUserID,
HashedToken: mockToken1, HashedToken: mockToken1,
}, })
}, assert.NoError(t, err, "failed to create PAT")
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) {
func TestUser_GetAllPATs(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
mockTokenID1: { ID: mockTokenID1,
ID: mockTokenID1, UserID: mockUserID,
HashedToken: mockToken1, HashedToken: mockToken1,
}, })
mockTokenID2: { assert.NoError(t, err, "failed to create PAT")
ID: mockTokenID2,
HashedToken: mockToken2, err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
}, ID: mockTokenID2,
}, UserID: mockUserID,
} HashedToken: mockToken2,
err := store.SaveAccount(context.Background(), account) })
if err != nil { assert.NoError(t, err, "failed to create PAT")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) {
func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when creating service user: %s", err) t.Fatalf("Error when creating service user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(account.Users)) assert.Equal(t, 2, len(account.Users))
@@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) {
func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when creating user: %s", err) t.Fatalf("Error when creating user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, user.IsServiceUser) assert.True(t, user.IsServiceUser)
@@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
func TestUser_InviteNewUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -549,13 +519,12 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = tt.serviceUser
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser)
assert.NoError(t, err, "failed to create service user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -582,12 +551,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -603,39 +569,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2" err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[targetId] = &User{ assert.NoError(t, err, "failed to create account")
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5" err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
account.Users[targetId] = &User{ {
Id: targetId, Id: "user2",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: true,
Role: UserRoleOwner, ServiceUserName: "user2username",
} },
{
err := store.SaveAccount(context.Background(), account) Id: "user3",
if err != nil { AccountID: mockAccountID,
t.Fatalf("Error when saving account: %s", err) IsServiceUser: false,
} Issued: UserIssuedAPI,
},
{
Id: "user4",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedIntegration,
},
{
Id: "user5",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
},
})
assert.NoError(t, err, "failed to save users")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -685,61 +650,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
func TestUser_DeleteUser_RegularUsers(t *testing.T) { func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2" err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[targetId] = &User{ assert.NoError(t, err, "failed to create account")
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5" err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
account.Users[targetId] = &User{ {
Id: targetId, Id: "user2",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: true,
Role: UserRoleOwner, ServiceUserName: "user2username",
} },
account.Users["user6"] = &User{ {
Id: "user6", Id: "user3",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: false,
} Issued: UserIssuedAPI,
account.Users["user7"] = &User{ },
Id: "user7", {
IsServiceUser: false, Id: "user4",
Issued: UserIssuedAPI, AccountID: mockAccountID,
} IsServiceUser: false,
account.Users["user8"] = &User{ Issued: UserIssuedIntegration,
Id: "user8", },
IsServiceUser: false, {
Issued: UserIssuedAPI, Id: "user5",
Role: UserRoleAdmin, AccountID: mockAccountID,
} IsServiceUser: false,
account.Users["user9"] = &User{ Issued: UserIssuedAPI,
Id: "user9", Role: UserRoleOwner,
IsServiceUser: false, },
Issued: UserIssuedAPI, {
Role: UserRoleAdmin, Id: "user6",
} AccountID: mockAccountID,
IsServiceUser: false,
err := store.SaveAccount(context.Background(), account) Issued: UserIssuedAPI,
if err != nil { },
t.Fatalf("Error when saving account: %s", err) {
} Id: "user7",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user8",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
{
Id: "user9",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
})
assert.NoError(t, err)
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -816,7 +784,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
acc, err := am.Store.GetAccount(context.Background(), account.Id) acc, err := am.Store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {
@@ -836,12 +804,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -865,14 +830,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} newUser := NewRegularUser("normal_user1")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
newUser = NewRegularUser("normal_user2")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -946,15 +916,24 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID)
assert.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings)
assert.NoError(t, err, "failed to save account settings")
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID)
assert.NoError(t, err, "failed to delete user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -968,7 +947,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
assert.Equal(t, 1, len(users)) assert.Equal(t, 1, len(users))
userInfo, _ := users[0].ToUserInfo(nil, account.Settings) userInfo, _ := users[0].ToUserInfo(nil, settings)
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
}) })
} }
@@ -978,22 +957,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
externalUser := &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: "externalUser", assert.NoError(t, err, "failed to create account")
Role: UserRoleUser,
Issued: UserIssuedIntegration, err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: "externalUser",
AccountID: mockAccountID,
Role: UserRoleUser,
Issued: UserIssuedIntegration,
IntegrationReference: integration_reference.IntegrationReference{ IntegrationReference: integration_reference.IntegrationReference{
ID: 1, ID: 1,
IntegrationType: "external", IntegrationType: "external",
}, },
} })
account.Users[externalUser.Id] = externalUser assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -1013,6 +991,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager() cacheManager := am.GetExternalCacheManager()
externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser")
assert.NoError(t, err, "failed to get user")
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"})
assert.NoError(t, err) assert.NoError(t, err)
@@ -1042,17 +1024,17 @@ func TestUser_IsAdmin(t *testing.T) {
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID, Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
} })
assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -1071,17 +1053,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{ assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID, Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
} })
assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@@ -1240,21 +1221,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// create an account and an admin user // create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// create other users // create other users
account.Users[regularUserID] = NewRegularUser(regularUserID) regularUser := NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID) regularUser.AccountID = accountID
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(context.Background(), account) adminUser := NewAdminUser(adminUserID)
if err != nil { adminUser.AccountID = accountID
t.Fatal(err)
serviceUser := &User{
Id: serviceUserID,
AccountID: accountID,
IsServiceUser: true,
Role: UserRoleAdmin,
ServiceUserName: "service",
} }
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser})
assert.NoError(t, err, "failed to save users")
updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update)
if tc.expectedErr { if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error") require.Errorf(t, err, "expecting SaveUser to throw an error")
} else { } else {

View File

@@ -88,18 +88,18 @@ type Route struct {
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// Network and Domains are mutually exclusive // Network and Domains are mutually exclusive
Network netip.Prefix `gorm:"serializer:json"` Network netip.Prefix `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"` Domains domain.List `gorm:"serializer:json"`
KeepRoute bool KeepRoute bool
NetID NetID NetID NetID
Description string Description string
Peer string Peer string
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
Metric int Metric int
Enabled bool Enabled bool
Groups []string `gorm:"serializer:json"` Groups []string `gorm:"serializer:json"`
AccessControlGroups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"`
} }
@@ -111,19 +111,20 @@ func (r *Route) EventMeta() map[string]any {
// Copy copies a route object // Copy copies a route object
func (r *Route) Copy() *Route { func (r *Route) Copy() *Route {
route := &Route{ route := &Route{
ID: r.ID, ID: r.ID,
Description: r.Description, AccountID: r.AccountID,
NetID: r.NetID, Description: r.Description,
Network: r.Network, NetID: r.NetID,
Domains: slices.Clone(r.Domains), Network: r.Network,
KeepRoute: r.KeepRoute, Domains: slices.Clone(r.Domains),
NetworkType: r.NetworkType, KeepRoute: r.KeepRoute,
Peer: r.Peer, NetworkType: r.NetworkType,
PeerGroups: slices.Clone(r.PeerGroups), Peer: r.Peer,
Metric: r.Metric, PeerGroups: slices.Clone(r.PeerGroups),
Masquerade: r.Masquerade, Metric: r.Metric,
Enabled: r.Enabled, Masquerade: r.Masquerade,
Groups: slices.Clone(r.Groups), Enabled: r.Enabled,
Groups: slices.Clone(r.Groups),
AccessControlGroups: slices.Clone(r.AccessControlGroups), AccessControlGroups: slices.Clone(r.AccessControlGroups),
} }
return route return route
@@ -138,6 +139,7 @@ func (r *Route) IsEqual(other *Route) bool {
} }
return other.ID == r.ID && return other.ID == r.ID &&
other.AccountID == r.AccountID &&
other.Description == r.Description && other.Description == r.Description &&
other.NetID == r.NetID && other.NetID == r.NetID &&
other.Network == r.Network && other.Network == r.Network &&
@@ -149,7 +151,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.Masquerade == r.Masquerade && other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled && other.Enabled == r.Enabled &&
slices.Equal(r.Groups, other.Groups) && slices.Equal(r.Groups, other.Groups) &&
slices.Equal(r.PeerGroups, other.PeerGroups)&& slices.Equal(r.PeerGroups, other.PeerGroups) &&
slices.Equal(r.AccessControlGroups, other.AccessControlGroups) slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
} }