diff --git a/go.mod b/go.mod index 06dc921f1..2f76c0766 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/mdlayher/socket v0.5.1 @@ -183,7 +184,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 5953671a4..090f52920 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -46,6 +49,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 5 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -55,6 +63,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -774,6 +783,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -784,9 +800,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -796,71 +822,1040 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil - account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil account.InitOnce() + return account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var networkNet, dnsSettingsDisabledGroups []byte + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups []byte + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange []byte + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups []byte + ) + + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, err + } + + _ = json.Unmarshal(networkNet, &account.Network.Net) + _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups != nil { + _ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) + } + if sNetworkRange != nil { + _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups != nil { + _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) + } return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + var ip, extraDNS, netAddr, env, flags, files, connIP []byte + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &account.Onboarding.CreatedAt, + &account.Onboarding.UpdatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -1200,8 +2195,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectToPgDb(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil } // NewMysqlStore creates a new MySQL store. diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go new file mode 100644 index 000000000..8ff04d68a --- /dev/null +++ b/management/server/store/sql_store_get_account_test.go @@ -0,0 +1,1089 @@ +package store + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/integration_reference" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads +// all fields and nested objects from the database, including deeply nested structures. +func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: []types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t *testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== + t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") + assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, "User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] + require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..10c6385ed --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,943 @@ +package store + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload(clause.Associations). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 12 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + fmt.Println("Successfully connected to the database!") + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, string) { + dsn := "host=localhost user=postgres password=mysecretpassword dbname=testdb port=5432 sslmode=disable TimeZone=Europe/Berlin" + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDBforTest(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + db.Migrator().DropTable(models[i]) + } + + err = db.AutoMigrate(models...) + if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, accountID +} + +func BenchmarkGetAccount(b *testing.B) { + store, accountID := setupBenchmarkDB(b) + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("gorm opt", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountGormOpt(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, accountID := setupBenchmarkDB(t) + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: r.Description, NetID: r.NetID, Network: r.Network,