diff --git a/management/server/account_test.go b/management/server/account_test.go index a0eff239b..de1194f86 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) + + account, err = manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, @@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } + account, err = manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) + err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) + + account, err = manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8687937dc..3be2dae18 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -22,80 +22,93 @@ type MockAccountManager struct { GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(accountID string) ([]*server.User, error) - GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error - DeletePeerFunc func(accountID, peerKey, userID string) error - GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(peerKey string) (*server.Network, error) - AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(accountID, userID string, group *group.Group) error - DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*group.Group, error) - GroupAddPeerFunc func(accountID, groupID, peerID string) error - GroupDeletePeerFunc func(accountID, groupID, peerID string) error - DeleteRuleFunc func(accountID, ruleID, userID string) error - GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(accountID, userID string, policy *server.Policy) error - DeletePolicyFunc func(accountID, policyID, userID string) error - ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) - MarkPATUsedFunc func(pat string) error - UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error - UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) - GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) - SaveRouteFunc func(accountID, userID string, route *route.Route) error - DeleteRouteFunc func(accountID, routeID, userID string) error - ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error - CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) - DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) - GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) - CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error - DeleteAccountFunc func(accountID, userID string) error - GetDNSDomainFunc func() string - StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error - GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) - SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) - InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error - GetAllConnectedPeersFunc func() (map[string]struct{}, error) - HasConnectedChannelFunc func(peerID string) bool - GetExternalCacheManagerFunc func() server.ExternalCacheManager - GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error - DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) - GetIdpManagerFunc func() idp.Manager + GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) + GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) + GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) + ListUsersFunc func(accountID string) ([]*server.User, error) + GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) + DeletePeerFunc func(accountID, peerKey, userID string) error + GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) + GetPeerNetworkFunc func(peerKey string) (*server.Network, error) + AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) + GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(accountID, userID string, group *group.Group) error + DeleteGroupFunc func(accountID, userId, groupID string) error + ListGroupsFunc func(accountID string) ([]*group.Group, error) + GroupAddPeerFunc func(accountID, groupID, peerID string) error + GroupDeletePeerFunc func(accountID, groupID, peerID string) error + DeleteRuleFunc func(accountID, ruleID, userID string) error + GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) + SavePolicyFunc func(accountID, userID string, policy *server.Policy) error + DeletePolicyFunc func(accountID, policyID, userID string) error + ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) + GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) + GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(pat string) error + UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error + UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) + GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) + SaveRouteFunc func(accountID, userID string, route *route.Route) error + DeleteRouteFunc func(accountID, routeID, userID string) error + ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) + ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) + SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) + SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) + DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error + CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error + DeleteAccountFunc func(accountID, userID string) error + GetDNSDomainFunc func() string + StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) + SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) + LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) + SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) + InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) + HasConnectedChannelFunc func(peerID string) bool + GetExternalCacheManagerFunc func() server.ExternalCacheManager + GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error + DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) + GetIdpManagerFunc func() idp.Manager UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error GroupValidationFunc func(accountId string, groups []string) (bool, error) } +func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) { + if am.SyncAndMarkPeerFunc != nil { + return am.SyncAndMarkPeerFunc(peerPubKey, realIP) + } + return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") +} + +func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { + // TODO implement me + panic("implement me") +} + func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { approvedPeers := make(map[string]struct{}) for id := range account.Peers { @@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID( } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error { +func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(peerKey, connected, realIP) } @@ -626,7 +639,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, * } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(sync) }