diff --git a/management/server/account.go b/management/server/account.go index a5b1a7070..19ff96cfb 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -105,7 +105,7 @@ type AccountManager interface { DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error - GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error @@ -438,7 +438,7 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con return nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration @@ -663,7 +663,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -1448,7 +1448,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -1495,7 +1495,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 41c1d5577..3d6d01434 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -122,7 +121,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare) + peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) if err != nil { log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) return diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index ce852fdc7..df8fe98c3 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -16,7 +16,7 @@ type MockStore struct { account *types.Account } -func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { +func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer for _, v := range s.account.Peers { if v.Ephemeral { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 18557cca0..09d63ea6f 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -72,8 +72,13 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } dnsDomain := h.accountManager.GetDNSDomain() - allGroups, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) - groupsInfo := groups.ToGroupsInfo(allGroups, peerID) + groupsMap := map[string]*types.Group{} + grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) + for _, group := range grps { + groupsMap[group.ID] = group + } + + groupsInfo := groups.ToGroupsInfo(groupsMap, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -122,7 +127,13 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID) + + groupsMap := map[string]*types.Group{} + for _, group := range peerGroups { + groupsMap[group.ID] = group + } + + groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -193,7 +204,11 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - allGroups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + groupsMap := map[string]*types.Group{} + grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range grps { + groupsMap[group.ID] = group + } respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { @@ -202,7 +217,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(allGroups, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -314,32 +329,6 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups []*nbgroup.Group, peerID string) []api.GroupMinimum { - groupsInfo := []api.GroupMinimum{} - groupsChecked := make(map[string]struct{}) - - for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } - } - return groupsInfo -} - func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index baab39ea9..16065a677 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -128,12 +128,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, - GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { peersID := make([]string, len(peers)) for _, peer := range peers { peersID = append(peersID, peer.ID) } - return []*nbgroup.Group{ + return []*types.Group{ { ID: "group1", AccountID: accountID, @@ -149,10 +149,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { return account, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 56c4e3e6f..9dad6fcd7 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -4,8 +4,6 @@ import ( "context" "errors" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" @@ -80,11 +78,11 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { var err error - var groups []*nbgroup.Group + var groups []*types.Group var peers []*nbpeer.Peer - var settings *Settings + var settings *types.Settings - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err @@ -102,7 +100,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 641254e4f..c8e42d20a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,7 +47,7 @@ type MockAccountManager struct { DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error) + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) @@ -843,7 +843,7 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } // GetPeerGroups mocks GetPeerGroups of the AccountManager interface -func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { if am.GetPeerGroupsFunc != nil { return am.GetPeerGroupsFunc(ctx, accountID, peerID) } diff --git a/management/server/peer.go b/management/server/peer.go index 16c824ccf..c09e3d051 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -12,8 +12,6 @@ import ( "time" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -122,11 +120,11 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { var peer *nbpeer.Peer - var settings *Settings + var settings *types.Settings var expired bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) if err != nil { return err @@ -163,7 +161,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -215,7 +213,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } var peer *nbpeer.Peer - var settings *Settings + var settings *types.Settings var peerGroupList []string var requiresPeerUpdates bool var peerLabelChanged bool @@ -223,7 +221,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var loginExpirationChanged bool var inactivityExpirationChanged bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) if err != nil { return err @@ -348,7 +346,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) if err != nil { return err @@ -575,7 +573,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - err = transaction.AddPeerToAccount(ctx, LockingStrengthUpdate, newPeer) + err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) if err != nil { return fmt.Errorf("failed to add peer to account: %w", err) } @@ -840,14 +838,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer = nil if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } // getPeerPostureChecks returns the posture checks for the peer. -func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, peerID string) ([]*posture.Checks, error) { +func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err @@ -872,7 +870,7 @@ func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, pee peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) } - peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, peerPostureChecksIDs) + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) if err != nil { return nil, err } @@ -881,7 +879,7 @@ func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, pee } // processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks. -func processPeerPostureChecks(ctx context.Context, transaction Store, policy *Policy, accountID, peerID string) ([]string, error) { +func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue @@ -980,7 +978,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return err } - err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.peer.GetLastLogin()) + err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) if err != nil { return err } @@ -1142,7 +1140,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are connected. func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) return 0, false @@ -1186,7 +1184,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are not connected. func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) return 0, false @@ -1227,7 +1225,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte // getExpiredPeers returns peers that have been expired. func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1250,7 +1248,7 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID // getInactivePeers returns peers that have been expired by inactivity func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1272,18 +1270,18 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID } // GetPeerGroups returns groups that the peer is part of. -func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { return getPeerGroups(ctx, am.Store, accountID, peerID) } // getPeerGroups returns the IDs of the groups that the peer is part of. -func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) { +func getPeerGroups(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*types.Group, error) { groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - peerGroups := make([]*nbgroup.Group, 0) + peerGroups := make([]*types.Group, 0) for _, group := range groups { if slices.Contains(group.Peers, peerID) { peerGroups = append(peerGroups, group) @@ -1294,7 +1292,7 @@ func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID str } // getPeerGroupIDs returns the IDs of the groups that the peer is part of. -func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) { +func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { groups, err := getPeerGroups(ctx, transaction, accountID, peerID) if err != nil { return nil, err @@ -1308,13 +1306,13 @@ func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, p return groupIDs, err } -func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) { +func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - existingLabels := make(lookupMap) + existingLabels := make(types.LookupMap) for _, label := range dnsLabels { existingLabels[label] = struct{}{} } @@ -1323,7 +1321,7 @@ func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) { +func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) if err != nil { return false, err @@ -1333,7 +1331,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peer // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. -func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() for _, peer := range peers { @@ -1362,7 +1360,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Sto FirewallRulesIsEmpty: true, }, }, - NetworkMap: &NetworkMap{}, + NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 1145e928f..630d4e426 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -2602,7 +2603,7 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { SSHEnabled: false, LoginExpirationEnabled: true, InactivityExpirationEnabled: false, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), CreatedAt: time.Now().UTC(), Ephemeral: true, } @@ -2623,7 +2624,7 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { assert.Equal(t, peer.SSHEnabled, storedPeer.SSHEnabled) assert.Equal(t, peer.LoginExpirationEnabled, storedPeer.LoginExpirationEnabled) assert.Equal(t, peer.InactivityExpirationEnabled, storedPeer.InactivityExpirationEnabled) - assert.WithinDurationf(t, peer.LastLogin, storedPeer.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") + assert.WithinDurationf(t, peer.GetLastLogin(), storedPeer.GetLastLogin().UTC(), time.Millisecond, "LastLogin should be equal") assert.WithinDurationf(t, peer.CreatedAt, storedPeer.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") assert.Equal(t, peer.Ephemeral, storedPeer.Ephemeral) assert.Equal(t, peer.Status.Connected, storedPeer.Status.Connected) diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 9e0a605fe..5990a0625 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); diff --git a/management/server/user.go b/management/server/user.go index b189b2c85..17770a423 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -979,7 +979,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { return err } am.StoreEvent(