diff --git a/go.mod b/go.mod index 2b4111ce3..81d829462 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 + github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 35abe82d2..ac496ce0a 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/server/account.go b/management/server/account.go index 79f9b3422..9e91a54b4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -29,7 +29,6 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -101,11 +100,11 @@ type AccountManager interface { GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) - GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error + GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error @@ -199,8 +198,8 @@ type DefaultAccountManager struct { // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. -func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { - existedGroupsByName := make(map[string]*nbgroup.Group) +func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []*types.Group, groupNames []string) (bool, []string, []*types.Group, error) { + existedGroupsByName := make(map[string]*types.Group) for _, group := range groups { existedGroupsByName[group.Name] = group } @@ -215,21 +214,21 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] return false, nil, nil, nil } - newGroupsToCreate := make([]*nbgroup.Group, 0) + newGroupsToCreate := make([]*types.Group, 0) var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { - group = &nbgroup.Group{ + group = &types.Group{ ID: xid.New().String(), AccountID: user.AccountID, Name: name, - Issued: nbgroup.GroupIssuedJWT, + Issued: types.GroupIssuedJWT, } newGroupsToCreate = append(newGroupsToCreate, group) } - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } @@ -1323,7 +1322,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting account groups: %w", err) } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } @@ -1741,15 +1740,15 @@ func (am *DefaultAccountManager) GetUserManager() users.Manager { // addAllGroup to account object if it doesn't exist func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ + allGroup := &types.Group{ ID: xid.New().String(), Name: "All", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} id := xid.New().String() @@ -1863,18 +1862,18 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID - allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + allGroupsMap := make(map[string]*types.Group, len(allGroups)) for _, group := range allGroups { allGroupsMap[group.ID] = group } for _, id := range autoGroups { if group, ok := allGroupsMap[id]; ok { - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { newAutoGroups = append(newAutoGroups, id) diff --git a/management/server/account_test.go b/management/server/account_test.go index 32cc1f290..ca8f21963 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -29,7 +29,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -53,7 +52,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} @@ -740,7 +739,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*group.Group{} + groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -748,12 +747,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match") }) } @@ -1248,7 +1247,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{}, @@ -1325,7 +1324,7 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -1373,7 +1372,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, @@ -1429,7 +1428,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1656,7 +1655,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1757,10 +1756,11 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, }, }, Policies: []*types.Policy{ @@ -2717,8 +2717,8 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*types.User{ @@ -2756,7 +2756,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { @@ -2776,7 +2776,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { @@ -2846,10 +2846,10 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2882,10 +2882,10 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 958034ae0..6fb9f6a29 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -19,7 +19,6 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -295,13 +294,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - newGroup1 := &group.Group{ + newGroup1 := &types.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &group.Group{ + newGroup2 := &types.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } @@ -485,7 +484,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -552,7 +551,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, diff --git a/management/server/group.go b/management/server/group.go index 068e845c3..cd228af65 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -16,7 +16,6 @@ import ( "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -48,7 +47,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco } // GetGroup returns a specific group by groupID in an account -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } @@ -56,7 +55,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } @@ -64,21 +63,21 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) + return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}) } // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err @@ -93,7 +92,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } var eventsToStore []func() - var groupsToSave []*nbgroup.Group + var groupsToSave []*types.Group var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -138,7 +137,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) @@ -226,7 +225,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us var allErrors error var groupIDsToDelete []string - var deletedGroups []*nbgroup.Group + var deletedGroups []*types.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { @@ -267,7 +266,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error @@ -303,12 +302,53 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return nil } +// GroupAddResource appends resource to the group +func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + + return nil +} + // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error @@ -344,13 +384,54 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } +// GroupDeleteResource removes resource from the group +func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemoveResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + + return nil +} + // validateNewGroup validates the new group for existence and required fields. -func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *nbgroup.Group) error { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { +func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *types.Group) error { + if newGroup.ID == "" && newGroup.Issued != types.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) if err != nil { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { @@ -377,9 +458,9 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st return nil } -func validateDeleteGroup(ctx context.Context, transaction store.Store, group *nbgroup.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user - if group.Issued == nbgroup.GroupIssuedIntegration { + if group.Issued == types.GroupIssuedIntegration { executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err @@ -417,7 +498,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *nb } // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. -func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *nbgroup.Group) error { +func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err diff --git a/management/server/group_test.go b/management/server/group_test.go index f46c310db..834388d1e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" @@ -33,22 +32,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedIntegration + group.Issued = types.GroupIssuedIntegration err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedJWT + group.Issued = types.GroupIssuedJWT err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) + t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI group.ID = "" err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { @@ -146,13 +145,13 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { manager, account, err := initTestGroupAccount(am) assert.NoError(t, err, "Failed to init testing account") - groups := make([]*nbgroup.Group, 10) + groups := make([]*types.Group, 10) for i := 0; i < 10; i++ { - groups[i] = &nbgroup.Group{ + groups[i] = &types.Group{ ID: fmt.Sprintf("group-%d", i+1), AccountID: account.Id, Name: fmt.Sprintf("group-%d", i+1), - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } } @@ -272,59 +271,59 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t accountID := "testingAcc" domain := "example.com" - groupForRoute := &nbgroup.Group{ + groupForRoute := &types.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForRoute2 := &nbgroup.Group{ + groupForRoute2 := &types.Group{ ID: "grp-for-route2", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &nbgroup.Group{ + groupForNameServerGroups := &types.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &nbgroup.Group{ + groupForPolicies := &types.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &nbgroup.Group{ + groupForSetupKeys := &types.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &nbgroup.Group{ + groupForUsers := &types.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &nbgroup.Group{ + groupForIntegration := &types.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: nbgroup.GroupIssuedIntegration, + Issued: types.GroupIssuedIntegration, Peers: make([]string, 0), } @@ -393,7 +392,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -430,7 +429,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, @@ -523,7 +522,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -592,7 +591,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, @@ -633,7 +632,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -660,7 +659,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 621875c6c..6a1088141 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -668,6 +668,10 @@ components: description: Count of peers associated to the group type: integer example: 2 + resources_count: + description: Count of resources associated to the group + type: integer + example: 5 issued: description: How the group was issued (api, integration, jwt) type: string @@ -677,6 +681,7 @@ components: - id - name - peers_count + - resources_count GroupRequest: type: object properties: @@ -690,6 +695,10 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m1" + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - name Group: @@ -702,8 +711,13 @@ components: type: array items: $ref: '#/components/schemas/PeerMinimum' + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - peers + - resources PolicyRuleMinimum: type: object properties: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 2d0adf140..0ffc6eabe 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -379,7 +379,11 @@ type Group struct { Peers []PeerMinimum `json:"peers"` // PeersCount Count of peers associated to the group - PeersCount int `json:"peers_count"` + PeersCount int `json:"peers_count"` + Resources []Resource `json:"resources"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupIssued How the group was issued (api, integration, jwt) @@ -398,6 +402,9 @@ type GroupMinimum struct { // PeersCount Count of peers associated to the group PeersCount int `json:"peers_count"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupMinimumIssued How the group was issued (api, integration, jwt) @@ -409,7 +416,8 @@ type GroupRequest struct { Name string `json:"name"` // Peers List of peers ids - Peers *[]string `json:"peers,omitempty"` + Peers *[]string `json:"peers,omitempty"` + Resources *[]Resource `json:"resources,omitempty"` } // Location Describe geographical location information @@ -1068,7 +1076,7 @@ type ProcessCheck struct { // Resource defines model for Resource. type Resource struct { - // Id Resource ID + // Id ID of the resource Id string `json:"id"` Type ResourceType `json:"type"` } diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e60529cec..ee52d8b4c 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -129,10 +129,21 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ ID: groupID, Name: req.Name, Peers: peers, + Resources: resources, Issued: existingGroup.Issued, IntegrationReference: existingGroup.IntegrationReference, } @@ -179,10 +190,21 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ - Name: req.Name, - Peers: peers, - Issued: nbgroup.GroupIssuedAPI, + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ + Name: req.Name, + Peers: peers, + Resources: resources, + Issued: types.GroupIssuedAPI, } err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) @@ -259,13 +281,19 @@ func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { } -func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { peersMap := make(map[string]*nbpeer.Peer, len(peers)) for _, peer := range peers { peersMap[peer.ID] = peer } - cache := make(map[string]api.PeerMinimum) + resMap := make(map[string]types.Resource, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + peerCache := make(map[string]api.PeerMinimum) + resCache := make(map[string]api.Resource) gr := api.Group{ Id: group.ID, Name: group.Name, @@ -273,7 +301,7 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { } for _, pid := range group.Peers { - _, ok := cache[pid] + _, ok := peerCache[pid] if !ok { peer, ok := peersMap[pid] if !ok { @@ -283,12 +311,27 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { Id: peer.ID, Name: peer.Name, } - cache[pid] = peerResp + peerCache[pid] = peerResp gr.Peers = append(gr.Peers, peerResp) } } gr.PeersCount = len(gr.Peers) + for _, res := range group.Resources { + _, ok := resCache[res.ID] + if !ok { + resource, ok := resMap[res.ID] + if !ok { + continue + } + resResp := resource.ToAPIResponse() + resCache[res.ID] = *resResp + gr.Resources = append(gr.Resources, *resResp) + } + } + + gr.ResourcesCount = len(gr.Resources) + return &gr } diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 089c1a40f..49805ca9b 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -17,13 +17,13 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -31,20 +31,20 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *handler { +func initGroupTestData(initGroups ...*types.Group) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - groups := map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*types.Group, error) { + groups := map[string]*types.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, } for _, group := range initGroups { @@ -61,9 +61,9 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { - return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } return nil, fmt.Errorf("unknown group name") @@ -120,7 +120,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &nbgroup.Group{ + group := &types.Group{ ID: "idofthegroup", Name: "Group", } @@ -154,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &nbgroup.Group{} + got := &types.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index a7f5a9425..4562766bd 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" @@ -200,7 +199,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - groupsMap := map[string]*nbgroup.Group{} + groupsMap := map[string]*types.Group{} groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) for _, group := range groups { groupsMap[group.ID] = group @@ -325,7 +324,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { +func toGroupsInfo(groups map[string]*types.Group, peerID string) []api.GroupMinimum { groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index da53e4ad7..83abc1c40 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -111,7 +110,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { regularUser: types.NewRegularUser(regularUser), serviceUser: srvUser, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "group1": { ID: "group1", AccountID: accountID, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 0255c773b..d538d07db 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" @@ -361,8 +360,8 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(groups []*nbgroup.Group, policy *types.Policy) *api.Policy { - groupsMap := make(map[string]*nbgroup.Group) +func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index d8db288d6..956d0b7cd 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" @@ -45,8 +44,8 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { } return policy, nil }, - GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { - return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -59,7 +58,7 @@ func initPoliciesTestData(policies ...*types.Policy) *handler { Policies: []*types.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 03be9d039..22b8026aa 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -4,8 +4,8 @@ import ( "context" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" ) // IntegratedValidator interface exists to avoid the circle dependencies @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/management_test.go b/management/server/management_test.go index f3fe6e69c..f0f83a237 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -23,10 +23,10 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -458,7 +458,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for p := range peers { validatedPeers[p] = struct{}{} diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 5cdb1ea64..1d356387f 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,7 +5,6 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -35,7 +34,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -120,7 +119,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 37a392c23..45d5eceb6 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,7 +13,6 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks" @@ -41,11 +40,11 @@ type MockAccountManager struct { GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error @@ -152,7 +151,7 @@ func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[str } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(ctx, accountId, groupID, userID) } @@ -160,7 +159,7 @@ func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(ctx, accountID, userID) } @@ -327,7 +326,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(ctx, accountID, groupName) } @@ -335,7 +334,7 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(ctx, accountID, userID, group) } @@ -343,7 +342,7 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s } // SaveGroups mock implementation of SaveGroups from server.AccountManager interface -func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error { +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { if am.SaveGroupsFunc != nil { return am.SaveGroupsFunc(ctx, accountID, userID, groups) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 31774320f..19acdf1ba 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -11,9 +11,9 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` @@ -306,7 +306,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*nbgroup.Group) error { +func validateGroups(list []string, groups map[string]*types.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index b6a1a6484..0743db513 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,7 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -844,12 +843,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &nbgroup.Group{ + newGroup1 := &types.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &nbgroup.Group{ + newGroup2 := &types.Group{ ID: group2ID, Name: group2ID, } @@ -946,7 +945,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 7574d0397..72a39441e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/proto" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -283,8 +282,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 nbgroup.Group - group2 nbgroup.Group + group1 types.Group + group2 types.Group ) group1.ID = xid.New().String() @@ -751,7 +750,7 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou account.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) - group := &nbgroup.Group{ + group := &types.Group{ ID: groupID, Name: fmt.Sprintf("Group %d", i), } @@ -1286,7 +1285,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", diff --git a/management/server/policy.go b/management/server/policy.go index 9f488c27e..8ae2f96d0 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -11,7 +11,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) @@ -239,7 +238,7 @@ func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureCh } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { +func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { if _, exists := groups[id]; exists { diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 10e1397b9..fab738abe 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/types" @@ -60,7 +59,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -308,7 +307,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -583,7 +582,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -830,7 +829,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 8565cc4f6..bad162f05 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -122,7 +121,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -445,18 +444,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { account, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") - groupA := &group.Group{ + groupA := &types.Group{ ID: "groupA", AccountID: account.Id, Peers: []string{"peer1"}, } - groupB := &group.Group{ + groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*group.Group{groupA, groupB}) + err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) require.NoError(t, err, "failed to save groups") postureCheckA := &posture.Checks{ diff --git a/management/server/route_test.go b/management/server/route_test.go index 040cb2c87..5e2e24611 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -15,7 +15,6 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" @@ -1096,7 +1095,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *nbgroup.Group + var groupHA1, groupHA2 *types.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -1204,7 +1203,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &nbgroup.Group{ + newGroup := &types.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1441,7 +1440,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou return nil, err } - newGroup := []*nbgroup.Group{ + newGroup := []*types.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1557,7 +1556,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "routingPeer1": { ID: "routingPeer1", Name: "RoutingPeer1", @@ -1911,7 +1910,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -2107,7 +2106,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2147,7 +2146,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 2cc0718c2..f728db5d4 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/types" ) @@ -31,7 +30,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -106,7 +105,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -115,7 +114,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -400,7 +399,7 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index d3a101bbf..9127c2705 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -11,7 +11,6 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" @@ -148,7 +147,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI } } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4ea1f5e4c..771a32aae 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -24,7 +24,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -90,7 +89,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( - &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &nbgroup.Group{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, @@ -437,7 +436,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -575,8 +574,8 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -659,7 +658,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } account.UsersG = nil - account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } @@ -1021,7 +1020,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1046,7 +1045,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer } func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1206,8 +1205,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { - var group *nbgroup.Group +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { + var group *types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1221,8 +1220,8 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { - var group nbgroup.Group +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { + var group types.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. @@ -1245,15 +1244,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } // GetGroupsByIDs retrieves groups by their IDs and account ID. -func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } - groupsMap := make(map[string]*nbgroup.Group) + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -1262,7 +1261,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } // SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) @@ -1274,7 +1273,7 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete group from store") @@ -1290,7 +1289,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 41d37f2a8..9bb7addcb 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -119,7 +118,7 @@ func runLargeTest(t *testing.T, store Store) { } account.Routes[route.ID] = route - group = &nbgroup.Group{ + group = &types.Group{ ID: fmt.Sprintf("group-id-%d", n), AccountID: account.Id, Name: fmt.Sprintf("group-id-%d", n), @@ -1201,7 +1200,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal(err) } - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: "account-id", Name: "group-name", @@ -1377,7 +1376,7 @@ func TestSqlStore_SaveGroup(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: accountID, Issued: "api", @@ -1398,7 +1397,7 @@ func TestSqlStore_SaveGroups(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - groups := []*nbgroup.Group{ + groups := []*types.Group{ { ID: "group-1", AccountID: accountID, @@ -2137,15 +2136,15 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty // addAllGroup to account object if it doesn't exist func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ + allGroup := &types.Group{ ID: xid.New().String(), Name: "All", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} id := xid.New().String() diff --git a/management/server/store/store.go b/management/server/store/store.go index b244d186b..07fef6cfd 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -20,8 +20,6 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/types" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -75,12 +73,12 @@ type Store interface { DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error diff --git a/management/server/types/account.go b/management/server/types/account.go index 34da7b43b..281c8ea63 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -16,7 +16,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" - nbgroup "github.com/netbirdio/netbird/management/server/group" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -59,8 +58,8 @@ type Account struct { PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*nbgroup.Group `gorm:"-"` - GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` @@ -214,7 +213,7 @@ func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain } // GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *nbgroup.Group { +func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] } @@ -609,7 +608,7 @@ func (a *Account) FindUser(userID string) (*User, error) { } // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { +func (a *Account) FindGroupByName(groupName string) (*Group, error) { for _, group := range a.Groups { if group.Name == groupName { return group, nil @@ -703,7 +702,7 @@ func (a *Account) Copy() *Account { setupKeys[id] = key.Copy() } - groups := map[string]*nbgroup.Group{} + groups := map[string]*Group{} for id, group := range a.Groups { groups[id] = group.Copy() } @@ -774,7 +773,7 @@ func (a *Account) Copy() *Account { } } -func (a *Account) GetGroupAll() (*nbgroup.Group, error) { +func (a *Account) GetGroupAll() (*Group, error) { for _, g := range a.Groups { if g.Name == "All" { return g, nil @@ -910,7 +909,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, all, err := a.GetGroupAll() if err != nil { log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &nbgroup.Group{} + all = &Group{} } return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { diff --git a/management/server/group/group.go b/management/server/types/group.go similarity index 67% rename from management/server/group/group.go rename to management/server/types/group.go index 24c60d3ce..7ba4b8656 100644 --- a/management/server/group/group.go +++ b/management/server/types/group.go @@ -1,6 +1,8 @@ -package group +package types -import "github.com/netbirdio/netbird/management/server/integration_reference" +import ( + "github.com/netbirdio/netbird/management/server/integration_reference" +) const ( GroupIssuedAPI = "api" @@ -25,6 +27,9 @@ type Group struct { // Peers list of the group Peers []string `gorm:"serializer:json"` + // Resources contains a list of resources in that group + Resources []Resource `gorm:"serializer:json"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } @@ -39,9 +44,11 @@ func (g *Group) Copy() *Group { Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.Resources, g.Resources) return group } @@ -81,3 +88,26 @@ func (g *Group) RemovePeer(peerID string) bool { } return false } + +// AddResource adds resource to Resources if not present, returning true if added. +func (g *Group) AddResource(resource Resource) bool { + for _, item := range g.Resources { + if item == resource { + return false + } + } + + g.Resources = append(g.Resources, resource) + return true +} + +// RemoveResource removes resource from Resources if present, returning true if removed. +func (g *Group) RemoveResource(resource Resource) bool { + for i, item := range g.Resources { + if item == resource { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/types/group_test.go similarity index 99% rename from management/server/group/group_test.go rename to management/server/types/group_test.go index cb002f8d9..12107c603 100644 --- a/management/server/group/group_test.go +++ b/management/server/types/group_test.go @@ -1,4 +1,4 @@ -package group +package types import ( "testing" diff --git a/management/server/user.go b/management/server/user.go index d2fa4434d..9fc2464de 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -13,7 +13,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -1143,8 +1142,8 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return @@ -1177,7 +1176,7 @@ func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[strin } // addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { groupPeers := make(map[string]struct{}, len(group.Peers)) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -1194,7 +1193,7 @@ func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) } // removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { // skip removing peers from group All if group.Name == "All" { return diff --git a/management/server/user_test.go b/management/server/user_test.go index 79e356f94..75d88f9c8 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -11,7 +11,6 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -1365,7 +1364,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID},