From 5e79cc0176148d55bc8f1ecc95ab15d9f1c0f24d Mon Sep 17 00:00:00 2001 From: crn4 Date: Thu, 16 Oct 2025 23:54:17 +0200 Subject: [PATCH 01/14] new raw sql get account method --- management/server/store/sql_store.go | 668 +++++++++++++++++++++++++-- route/route.go | 1 + 2 files changed, 641 insertions(+), 28 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 382d026c8..de1ea1a07 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -15,6 +15,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -55,6 +57,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -774,6 +777,560 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + network_identifier, network_net, network_dns, network_serial, + dns_settings_disabled_management_groups + FROM accounts WHERE id = $1` + + var networkNet, dnsSettingsDisabledGroups []byte + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, + &dnsSettingsDisabledGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, err + } + _ = json.Unmarshal(networkNet, &account.Network.Net) + _ = json.Unmarshal(dnsSettingsDisabledGroups, 
&account.DNSSettings.DisabledManagementGroups) + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &sk.ExpiresAt, &sk.UpdatedAt, &sk.Revoked, &sk.UsedTimes, &sk.LastUsed, &autoGroups, &sk.UsageLimit, &sk.Ephemeral, &sk.AllowExtraDNSLabels) + if err == nil && autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + return sk, err + }) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + 
errChan <- err + return + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ip, extraDNS, netAddr, env, flags, files, connIP []byte + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + if err == nil { + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + users, err := pgx.CollectRows(rows, 
func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &u.IsServiceUser, &u.NonDeletable, &u.ServiceUserName, &autoGroups, &u.Blocked, &u.PendingApproval, &u.LastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil && autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } + return u, err + }) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &g.IntegrationReference.ID, &g.IntegrationReference.IntegrationType) + if err == nil && resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } + return &g, err + }) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks) + if err == nil && checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + return &p, err + }) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + 
wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply) + if err == nil { + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled) + if err == nil { + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } + if groups != nil { + _ = 
json.Unmarshal(groups, &n.Groups) + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } + } + return n, err + }) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + errChan <- err + return + } + account.Networks = make([]*networkTypes.Network, len(networks)) + for i := range networks { + account.Networks[i] = &networks[i] + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled) + if err == nil && peerGroups != 
nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + return &r, err + }) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled) + if err == nil && prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + return &r, err + }) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &account.Onboarding.OnboardingFlowPending, + &account.Onboarding.SignupFormPending, + &account.Onboarding.CreatedAt, + &account.Onboarding.UpdatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + if 
len(userIDs) == 0 { + return + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + errChan <- err + return + } + pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken]) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + if len(policyIDs) == 0 { + return + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + errChan <- err + return + } + rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + if len(groupIDs) == 0 { + return + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if 
err != nil { + errChan <- err + return + } + groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, 
len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -784,9 +1341,20 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). + // Omit("GroupsG"). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). 
Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -797,24 +1365,27 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } + // for i, policy := range account.Policies { + // var rules []*types.PolicyRule + // err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + // if err != nil { + // return nil, status.Errorf(status.NotFound, "rule not found") + // } + // account.Policies[i].Rules = rules + // } account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil @@ -822,38 +1393,45 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := 
range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) - } - } + // var groupPeers []types.GroupPeer + // s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + // Find(&groupPeers) + // for _, groupPeer := range groupPeers { + // if group, ok := account.Groups[groupPeer.GroupID]; ok { + // group.Peers = append(group.Peers, groupPeer.PeerID) + // } else { + // log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + // } + // } account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + account.Routes[route.ID] = &route } account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + ns.AccountID = "" + account.NameServerGroups[ns.ID] = &ns } account.NameServerGroupsG = nil @@ -1199,8 +1777,42 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectDB(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, 
err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectDB(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 10 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + fmt.Println("Successfully connected to the database!") + return pool, nil } // NewMysqlStore creates a new MySQL store. diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: r.Description, NetID: r.NetID, Network: r.Network, From 23466adbae3a82cf96305ce8c4912c718e378267 Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 00:16:05 +0200 Subject: [PATCH 02/14] with test --- .../server/store/sqlstore_bench_test.go | 1323 +++++++++++++++++ 1 file changed, 1323 insertions(+) create mode 100644 management/server/store/sqlstore_bench_test.go diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..add3fffe5 --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,1323 @@ +package store + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5" + 
"github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload(clause.Associations). 
+ Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). 
+ Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 10 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + fmt.Println("Successfully connected to the database!") + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, string) { + dsn := "host=localhost user=postgres password=mysecretpassword dbname=testdb port=5432 sslmode=disable TimeZone=Europe/Berlin" + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDB(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, 
&types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + db.Migrator().DropTable(models[i]) + } + + err = db.AutoMigrate(models...) + if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: 
net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: 
accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: 
networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, accountID +} + +func BenchmarkGetAccount(b *testing.B) { + store, accountID := setupBenchmarkDB(b) + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("new", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountPureSQL(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, accountID := setupBenchmarkDB(t) + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccount}, + {"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + 
testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + normalizeNilSlices(expected) + normalizeNilSlices(actual) + + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have 
the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range 
expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + 
assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func normalizeNilSlices(acc *types.Account) { + if acc == nil { + return + } + + if acc.Policies == nil { + acc.Policies = []*types.Policy{} + } + if acc.PostureChecks == nil { + acc.PostureChecks = []*posture.Checks{} + } + if acc.Networks == nil { + acc.Networks = []*networkTypes.Network{} + } + if acc.NetworkRouters == nil { + acc.NetworkRouters = []*routerTypes.NetworkRouter{} + } + if acc.NetworkResources == nil { + acc.NetworkResources = []*resourceTypes.NetworkResource{} + } + if acc.DNSSettings.DisabledManagementGroups == nil { + acc.DNSSettings.DisabledManagementGroups = []string{} + } + + for _, key := range acc.SetupKeys { + if 
key.AutoGroups == nil { + key.AutoGroups = []string{} + } + } + + for _, peer := range acc.Peers { + if peer.ExtraDNSLabels == nil { + peer.ExtraDNSLabels = []string{} + } + } + + for _, user := range acc.Users { + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + } + + for _, group := range acc.Groups { + if group.Peers == nil { + group.Peers = []string{} + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + if group.GroupPeers == nil { + group.GroupPeers = []types.GroupPeer{} + } + } + + for _, route := range acc.Routes { + if route.Domains == nil { + route.Domains = domain.List{} + } + if route.PeerGroups == nil { + route.PeerGroups = []string{} + } + if route.Groups == nil { + route.Groups = []string{} + } + if route.AccessControlGroups == nil { + route.AccessControlGroups = []string{} + } + } + + for _, nsg := range acc.NameServerGroups { + if nsg.NameServers == nil { + nsg.NameServers = []nbdns.NameServer{} + } + if nsg.Groups == nil { + nsg.Groups = []string{} + } + if nsg.Domains == nil { + nsg.Domains = []string{} + } + } + + for _, policy := range acc.Policies { + if policy.SourcePostureChecks == nil { + policy.SourcePostureChecks = []string{} + } + if policy.Rules == nil { + policy.Rules = []*types.PolicyRule{} + } + for _, rule := range policy.Rules { + if rule.Destinations == nil { + rule.Destinations = []string{} + } + if rule.Sources == nil { + rule.Sources = []string{} + } + if rule.Ports == nil { + rule.Ports = []string{} + } + if rule.PortRanges == nil { + rule.PortRanges = []types.RulePortRange{} + } + } + } + + for _, check := range acc.PostureChecks { + if check.Checks.GeoLocationCheck != nil { + if check.Checks.GeoLocationCheck.Locations == nil { + check.Checks.GeoLocationCheck.Locations = []posture.Location{} + } + } + if check.Checks.PeerNetworkRangeCheck != nil { + if check.Checks.PeerNetworkRangeCheck.Ranges == nil { + check.Checks.PeerNetworkRangeCheck.Ranges = []netip.Prefix{} + } + } + if 
check.Checks.ProcessCheck != nil { + if check.Checks.ProcessCheck.Processes == nil { + check.Checks.ProcessCheck.Processes = []posture.Process{} + } + } + } + + for _, router := range acc.NetworkRouters { + if router.PeerGroups == nil { + router.PeerGroups = []string{} + } + } +} + +func TestGetAccountEquals(t *testing.T) { + store, accountID := setupBenchmarkDB(t) + ctx := context.Background() + expected, err := store.GetAccountSlow(ctx, accountID) + require.NoError(t, err) + actual, err := store.GetAccount(ctx, accountID) + require.NoError(t, err) + + require.Equal(t, expected.DNSSettings, actual.DNSSettings) + require.Equal(t, expected.Domain, actual.Domain) + require.Equal(t, expected.DomainCategory, actual.DomainCategory) + require.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount) + require.Equal(t, expected.Id, actual.Id) + require.Equal(t, expected.CreatedBy, actual.CreatedBy) + require.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + + require.Equal(t, len(expected.SetupKeys), len(actual.SetupKeys)) + for k, v := range expected.SetupKeys { + v2, ok := actual.SetupKeys[k] + require.True(t, ok) + require.Equal(t, v, v2) + } + + require.Equal(t, len(expected.Peers), len(actual.Peers)) + for k, v := range expected.Peers { + v2, ok := actual.Peers[k] + require.True(t, ok) + require.Equal(t, v, v2) + } + + require.Equal(t, len(expected.Users), len(actual.Users)) + for k, v := range expected.Users { + v2, ok := actual.Users[k] + require.True(t, ok) + require.Equal(t, v, v2) + require.Equal(t, len(v.PATs), len(v2.PATs)) + for k3, v3 := range v.PATs { + v4, ok := v2.PATs[k3] + require.True(t, ok) + require.Equal(t, v3, v4) + } + } + require.Equal(t, len(expected.Groups), len(actual.Groups)) + for k, v := range expected.Groups { + v2, ok := actual.Groups[k] + require.True(t, ok) + require.Equal(t, v, v2) + } + require.Equal(t, len(expected.Routes), len(actual.Routes)) + for k, v := range expected.Routes { + v2, ok := 
actual.Routes[k] + require.True(t, ok) + require.Equal(t, v, v2) + } + require.Equal(t, len(expected.NameServerGroups), len(actual.NameServerGroups)) + for k, v := range expected.NameServerGroups { + v2, ok := actual.NameServerGroups[k] + require.True(t, ok) + require.Equal(t, v, v2) + } + + require.Equal(t, expected.Policies, actual.Policies) + + require.Equal(t, expected.PostureChecks, actual.PostureChecks) + + require.Equal(t, expected.Network, actual.Network) + + require.Equal(t, expected.Networks, actual.Networks) + require.Equal(t, expected.NetworkRouters, actual.NetworkRouters) + require.Equal(t, expected.NetworkResources, actual.NetworkResources) + require.Equal(t, expected.Onboarding, actual.Onboarding) +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + network_identifier, network_net, network_dns, network_serial, + dns_settings_disabled_management_groups + FROM accounts WHERE id = $1` + + var networkNet, dnsSettingsDisabledGroups []byte + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, + &dnsSettingsDisabledGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, err + } + _ = json.Unmarshal(networkNet, &account.Network.Net) + _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, 
updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &sk.ExpiresAt, &sk.UpdatedAt, &sk.Revoked, &sk.UsedTimes, &sk.LastUsed, &autoGroups, &sk.UsageLimit, &sk.Ephemeral, &sk.AllowExtraDNSLabels) + if err == nil && autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + return sk, err + }) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ip, extraDNS, netAddr, env, flags, files, connIP []byte + err := 
row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + if err == nil { + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &u.IsServiceUser, &u.NonDeletable, &u.ServiceUserName, &autoGroups, &u.Blocked, &u.PendingApproval, 
&u.LastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil && autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } + return u, err + }) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &g.IntegrationReference.ID, &g.IntegrationReference.IntegrationType) + if err == nil && resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } + return &g, err + }) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks) + if err == nil && checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + return &p, err + }) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply 
FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply) + if err == nil { + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled) + if err == nil { + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } + } + return n, err + }) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + 
defer wg.Done() + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + errChan <- err + return + } + account.Networks = make([]*networkTypes.Network, len(networks)) + for i := range networks { + account.Networks[i] = &networks[i] + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled) + if err == nil && peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + return &r, err + }) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT id, 
network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + errChan <- err + return + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled) + if err == nil && prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + return &r, err + }) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &account.Onboarding.OnboardingFlowPending, + &account.Onboarding.SignupFormPending, + &account.Onboarding.CreatedAt, + &account.Onboarding.UpdatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + if len(userIDs) == 0 { + return + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + 
if err != nil { + errChan <- err + return + } + pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken]) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + if len(policyIDs) == 0 { + return + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + errChan <- err + return + } + rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + if len(groupIDs) == 0 { + return + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + errChan <- err + return + } + groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + 
return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range 
account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return &account, nil +} From 682998a788b2c4f2c87779fc732ffba5c843c447 Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 14:34:22 +0200 Subject: [PATCH 03/14] nil slices to empty --- management/server/store/sql_store.go | 196 +++++++++++++++-- .../server/store/sqlstore_bench_test.go | 202 ++++++++++++++++-- 2 files changed, 369 insertions(+), 29 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index de1ea1a07..44d3a031c 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -782,15 +783,55 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc const accountQuery = ` SELECT id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network network_identifier, network_net, network_dns, network_serial, - dns_settings_disabled_management_groups + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, 
settings_extra_integrated_validator_groups FROM accounts WHERE id = $1` var networkNet, dnsSettingsDisabledGroups []byte + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups []byte + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange []byte + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups []byte + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -798,9 +839,64 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } return nil, err } + _ = json.Unmarshal(networkNet, &account.Network.Net) _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if sPeerLoginExpirationEnabled.Valid { + 
account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups != nil { + _ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) + } + if sNetworkRange != nil { + _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups != nil { + _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) + } + var wg 
sync.WaitGroup errChan := make(chan error, 12) @@ -817,12 +913,45 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { var sk types.SetupKey var autoGroups []byte - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &sk.ExpiresAt, &sk.UpdatedAt, &sk.Revoked, &sk.UsedTimes, &sk.LastUsed, &autoGroups, &sk.UsageLimit, &sk.Ephemeral, &sk.AllowExtraDNSLabels) - if err == nil && autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt + var expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } } return sk, err }) @@ -892,9 +1021,30 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { var u types.User var autoGroups 
[]byte - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &u.IsServiceUser, &u.NonDeletable, &u.ServiceUserName, &autoGroups, &u.Blocked, &u.PendingApproval, &u.LastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil && autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) + var lastLogin sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } } return u, err }) @@ -917,9 +1067,23 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { var g types.Group var resources []byte - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &g.IntegrationReference.ID, &g.IntegrationReference.IntegrationType) - if err == nil && resources != nil { - _ = json.Unmarshal(resources, &g.Resources) + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = 
json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} } return &g, err }) @@ -1010,12 +1174,18 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc if err == nil { if ns != nil { _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} } if groups != nil { _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} } if domains != nil { _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} } } return n, err diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index add3fffe5..f7720df89 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -434,7 +435,7 @@ func TestAccountEquivalence(t *testing.T) { expectedF getAccountFunc actualF getAccountFunc }{ - {"old vs new", store.GetAccountSlow, store.GetAccount}, + // {"old vs new", store.GetAccountSlow, store.GetAccount}, {"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL}, } @@ -461,8 +462,8 @@ func TestAccountEquivalence(t *testing.T) { } func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { - normalizeNilSlices(expected) - normalizeNilSlices(actual) + // normalizeNilSlices(expected) + // normalizeNilSlices(actual) assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") @@ -774,15 +775,55 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty const accountQuery = ` SELECT id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network network_identifier, network_net, network_dns, network_serial, - 
dns_settings_disabled_management_groups + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups FROM accounts WHERE id = $1` var networkNet, dnsSettingsDisabledGroups []byte + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups []byte + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange []byte + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups []byte + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + 
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -790,9 +831,64 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty } return nil, err } + _ = json.Unmarshal(networkNet, &account.Network.Net) _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups != nil { + _ = 
json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) + } + if sNetworkRange != nil { + _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups != nil { + _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) + } + var wg sync.WaitGroup errChan := make(chan error, 12) @@ -809,12 +905,45 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { var sk types.SetupKey var autoGroups []byte - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &sk.ExpiresAt, &sk.UpdatedAt, &sk.Revoked, &sk.UsedTimes, &sk.LastUsed, &autoGroups, &sk.UsageLimit, &sk.Ephemeral, &sk.AllowExtraDNSLabels) - if err == nil && autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt + var expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = 
&lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } } return sk, err }) @@ -884,9 +1013,30 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { var u types.User var autoGroups []byte - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &u.IsServiceUser, &u.NonDeletable, &u.ServiceUserName, &autoGroups, &u.Blocked, &u.PendingApproval, &u.LastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil && autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) + var lastLogin sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } } return u, err }) @@ -909,9 +1059,23 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty groups, 
err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { var g types.Group var resources []byte - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &g.IntegrationReference.ID, &g.IntegrationReference.IntegrationType) - if err == nil && resources != nil { - _ = json.Unmarshal(resources, &g.Resources) + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} } return &g, err }) @@ -1002,12 +1166,18 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty if err == nil { if ns != nil { _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} } if groups != nil { _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} } if domains != nil { _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} } } return n, err From 7859d66e3408216b504822700ca4fd00abc82da7 Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 10:19:29 +0200 Subject: [PATCH 04/14] go mod tidy --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 31b45e881..cb88f92d3 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/mdlayher/socket v0.5.1 @@ -183,7 +184,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect 
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect From 8e3f0090f0b2cbea7612f41dc792ed813e9aee49 Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 11:03:04 +0200 Subject: [PATCH 05/14] null bools --- .../server/store/sqlstore_bench_test.go | 114 ++++++------------ 1 file changed, 38 insertions(+), 76 deletions(-) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index f7720df89..dc5754208 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -21,7 +21,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -696,79 +695,6 @@ func normalizeNilSlices(acc *types.Account) { } } -func TestGetAccountEquals(t *testing.T) { - store, accountID := setupBenchmarkDB(t) - ctx := context.Background() - expected, err := store.GetAccountSlow(ctx, accountID) - require.NoError(t, err) - actual, err := store.GetAccount(ctx, accountID) - require.NoError(t, err) - - require.Equal(t, expected.DNSSettings, actual.DNSSettings) - require.Equal(t, expected.Domain, actual.Domain) - require.Equal(t, expected.DomainCategory, actual.DomainCategory) - require.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount) - require.Equal(t, expected.Id, actual.Id) - require.Equal(t, expected.CreatedBy, actual.CreatedBy) - require.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) - - require.Equal(t, len(expected.SetupKeys), 
len(actual.SetupKeys)) - for k, v := range expected.SetupKeys { - v2, ok := actual.SetupKeys[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - - require.Equal(t, len(expected.Peers), len(actual.Peers)) - for k, v := range expected.Peers { - v2, ok := actual.Peers[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - - require.Equal(t, len(expected.Users), len(actual.Users)) - for k, v := range expected.Users { - v2, ok := actual.Users[k] - require.True(t, ok) - require.Equal(t, v, v2) - require.Equal(t, len(v.PATs), len(v2.PATs)) - for k3, v3 := range v.PATs { - v4, ok := v2.PATs[k3] - require.True(t, ok) - require.Equal(t, v3, v4) - } - } - require.Equal(t, len(expected.Groups), len(actual.Groups)) - for k, v := range expected.Groups { - v2, ok := actual.Groups[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - require.Equal(t, len(expected.Routes), len(actual.Routes)) - for k, v := range expected.Routes { - v2, ok := actual.Routes[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - require.Equal(t, len(expected.NameServerGroups), len(actual.NameServerGroups)) - for k, v := range expected.NameServerGroups { - v2, ok := actual.NameServerGroups[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - - require.Equal(t, expected.Policies, actual.Policies) - - require.Equal(t, expected.PostureChecks, actual.PostureChecks) - - require.Equal(t, expected.Network, actual.Network) - - require.Equal(t, expected.Networks, actual.Networks) - require.Equal(t, expected.NetworkRouters, actual.NetworkRouters) - require.Equal(t, expected.NetworkResources, actual.NetworkResources) - require.Equal(t, expected.Onboarding, actual.Onboarding) -} - func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { var account types.Account account.Network = &types.Network{} @@ -966,10 +892,46 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty peers, err := pgx.CollectRows(rows, func(row 
pgx.CollectableRow) (nbpeer.Peer, error) { var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool var ip, extraDNS, netAddr, env, flags, files, connIP []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if 
loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if ip != nil { _ = json.Unmarshal(ip, &p.IP) } From af29a18a109deb3c1ca03285e642b456b4adb998 Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 11:25:15 +0200 Subject: [PATCH 06/14] more null fields --- .../server/store/sqlstore_bench_test.go | 121 ++++++++++++++---- 1 file changed, 98 insertions(+), 23 deletions(-) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index dc5754208..f446ffe3a 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -867,8 +867,6 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} } } return sk, err @@ -892,6 +890,7 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} var lastLogin sql.NullTime var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool var peerStatusLastSeen sql.NullTime @@ -996,8 +995,6 @@ func (s *SqlStore) 
GetAccountPureSQL(ctx context.Context, accountID string) (*ty } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} } } return u, err @@ -1060,9 +1057,15 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { var p types.Policy var checks []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks) - if err == nil && checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } } return &p, err }) @@ -1085,8 +1088,25 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { var r route.Route var network, domains, peerGroups, groups, accessGroups []byte - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply) + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } if network != nil { _ = 
json.Unmarshal(network, &r.Network) } @@ -1124,8 +1144,18 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { var n nbdns.NameServerGroup var ns, groups, domains []byte - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled) + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } if ns != nil { _ = json.Unmarshal(ns, &n.NameServers) } else { @@ -1205,20 +1235,36 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty errChan <- err return } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) { + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { var r routerTypes.NetworkRouter var peerGroups []byte - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled) - if err == nil && peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } } - return &r, err + return r, err }) if err != nil { errChan <- err 
return } - account.NetworkRouters = routers + account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + account.NetworkRouters[i] = &routers[i] + } }() wg.Add(1) @@ -1230,20 +1276,29 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty errChan <- err return } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) { + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { var r resourceTypes.NetworkResource var prefix []byte - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled) - if err == nil && prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } } - return &r, err + return r, err }) if err != nil { errChan <- err return } - account.NetworkResources = resources + account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + account.NetworkResources[i] = &resources[i] + } }() wg.Add(1) @@ -1298,7 +1353,20 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty errChan <- err return } - pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken]) + pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + 
} + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) if err != nil { errChan <- err } @@ -1319,8 +1387,15 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { var r types.PolicyRule var dest, destRes, sources, sourceRes, ports, portRanges []byte - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges) + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } if dest != nil { _ = json.Unmarshal(dest, &r.Destinations) } From 5320c89bddcc45cbb8ec8b58f50ed11558dcd0cc Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 11:44:25 +0200 Subject: [PATCH 07/14] more nullable fields --- management/server/store/sqlstore_bench_test.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index f446ffe3a..711b80b72 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -1030,11 +1030,7 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty } if resources != nil { _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} } return &g, err }) @@ -1305,15 +1301,23 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty go func() { defer wg.Done() const query = `SELECT account_id, 
onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool err := s.pool.QueryRow(ctx, query, accountID).Scan( &account.Onboarding.AccountID, - &account.Onboarding.OnboardingFlowPending, - &account.Onboarding.SignupFormPending, + &onboardingFlowPending, + &signupFormPending, &account.Onboarding.CreatedAt, &account.Onboarding.UpdatedAt, ) if err != nil && !errors.Is(err, pgx.ErrNoRows) { errChan <- err + return + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool } }() From f588997c496267843a749198544137ea9339e167 Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 11:53:58 +0200 Subject: [PATCH 08/14] change main get account method --- management/server/store/sql_store.go | 175 ++++++++++++++++++++++----- 1 file changed, 145 insertions(+), 30 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 44d3a031c..7124486ad 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -949,8 +949,6 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} } } return sk, err @@ -975,9 +973,46 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { var p nbpeer.Peer p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval 
sql.NullBool var ip, extraDNS, netAddr, env, flags, files, connIP []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if 
peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if ip != nil { _ = json.Unmarshal(ip, &p.IP) } @@ -1042,8 +1077,6 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} } } return u, err @@ -1079,11 +1112,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if resources != nil { _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} } return &g, err }) @@ -1106,9 +1135,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { var p types.Policy var checks []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks) - if err == nil && checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } } return &p, err }) @@ -1131,8 +1166,25 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { var r route.Route var network, domains, peerGroups, groups, accessGroups []byte - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, 
&r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply) + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } if network != nil { _ = json.Unmarshal(network, &r.Network) } @@ -1170,8 +1222,18 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { var n nbdns.NameServerGroup var ns, groups, domains []byte - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled) + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } if ns != nil { _ = json.Unmarshal(ns, &n.NameServers) } else { @@ -1251,20 +1313,36 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc errChan <- err return } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) { + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { var r 
routerTypes.NetworkRouter var peerGroups []byte - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled) - if err == nil && peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } } - return &r, err + return r, err }) if err != nil { errChan <- err return } - account.NetworkRouters = routers + account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + account.NetworkRouters[i] = &routers[i] + } }() wg.Add(1) @@ -1276,35 +1354,52 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc errChan <- err return } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) { + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { var r resourceTypes.NetworkResource var prefix []byte - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled) - if err == nil && prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } } - return &r, err + return r, err }) if err != nil { errChan <- err return } - account.NetworkResources = resources + account.NetworkResources = 
make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + account.NetworkResources[i] = &resources[i] + } }() wg.Add(1) go func() { defer wg.Done() const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool err := s.pool.QueryRow(ctx, query, accountID).Scan( &account.Onboarding.AccountID, - &account.Onboarding.OnboardingFlowPending, - &account.Onboarding.SignupFormPending, + &onboardingFlowPending, + &signupFormPending, &account.Onboarding.CreatedAt, &account.Onboarding.UpdatedAt, ) if err != nil && !errors.Is(err, pgx.ErrNoRows) { errChan <- err + return + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool } }() @@ -1344,7 +1439,20 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc errChan <- err return } - pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken]) + pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) if err != nil { errChan <- err } @@ -1365,8 +1473,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { var r types.PolicyRule var dest, destRes, sources, sourceRes, ports, portRanges []byte - err := 
row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges) + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } if dest != nil { _ = json.Unmarshal(dest, &r.Destinations) } From d68eb8cc930c97dfb8b350b15362eefaad177d9a Mon Sep 17 00:00:00 2001 From: crn4 Date: Fri, 17 Oct 2025 14:34:22 +0200 Subject: [PATCH 09/14] nil slices to empty --- management/server/store/sql_store.go | 8 ++++++++ management/server/store/sqlstore_bench_test.go | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 7124486ad..865051fec 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -949,6 +949,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} } } return sk, err @@ -1077,6 +1079,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} } } return u, err @@ -1112,7 +1116,11 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if resources != nil { _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} } return &g, err }) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 
711b80b72..2b0256bd8 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -867,6 +867,8 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} } } return sk, err @@ -995,6 +997,8 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} } } return u, err @@ -1030,7 +1034,11 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty } if resources != nil { _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} } return &g, err }) From feb14c4e54d72628738d30eb4facafd4524c3f7d Mon Sep 17 00:00:00 2001 From: crn4 Date: Sun, 19 Oct 2025 17:41:51 +0200 Subject: [PATCH 10/14] code cleanup --- management/server/store/sql_store.go | 894 +-------- .../server/store/sqlstore_bench_test.go | 1651 ++++++++--------- 2 files changed, 840 insertions(+), 1705 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 865051fec..41d2af0c2 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "encoding/json" "errors" "fmt" @@ -16,7 +15,6 @@ import ( "sync" "time" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" @@ -778,852 +776,6 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. 
} func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { - var account types.Account - account.Network = &types.Network{} - const accountQuery = ` - SELECT - id, created_by, created_at, domain, domain_category, is_domain_primary_account, - -- Embedded Network - network_identifier, network_net, network_dns, network_serial, - -- Embedded DNSSettings - dns_settings_disabled_management_groups, - -- Embedded Settings - settings_peer_login_expiration_enabled, settings_peer_login_expiration, - settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, - settings_regular_users_view_blocked, settings_groups_propagation_enabled, - settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, - settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, - settings_lazy_connection_enabled, - -- Embedded ExtraSettings - settings_extra_peer_approval_enabled, settings_extra_user_approval_required, - settings_extra_integrated_validator, settings_extra_integrated_validator_groups - FROM accounts WHERE id = $1` - - var networkNet, dnsSettingsDisabledGroups []byte - var ( - sPeerLoginExpirationEnabled sql.NullBool - sPeerLoginExpiration sql.NullInt64 - sPeerInactivityExpirationEnabled sql.NullBool - sPeerInactivityExpiration sql.NullInt64 - sRegularUsersViewBlocked sql.NullBool - sGroupsPropagationEnabled sql.NullBool - sJWTGroupsEnabled sql.NullBool - sJWTGroupsClaimName sql.NullString - sJWTAllowGroups []byte - sRoutingPeerDNSResolutionEnabled sql.NullBool - sDNSDomain sql.NullString - sNetworkRange []byte - sLazyConnectionEnabled sql.NullBool - sExtraPeerApprovalEnabled sql.NullBool - sExtraUserApprovalRequired sql.NullBool - sExtraIntegratedValidator sql.NullString - sExtraIntegratedValidatorGroups []byte - ) - - err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( - &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, 
&account.DomainCategory, &account.IsDomainPrimaryAccount, - &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, - &dnsSettingsDisabledGroups, - &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, - &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, - &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, - &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, - &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, - &sLazyConnectionEnabled, - &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, - &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, - ) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return nil, errors.New("account not found") - } - return nil, err - } - - _ = json.Unmarshal(networkNet, &account.Network.Net) - _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) - - account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} - if sPeerLoginExpirationEnabled.Valid { - account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool - } - if sPeerLoginExpiration.Valid { - account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) - } - if sPeerInactivityExpirationEnabled.Valid { - account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool - } - if sPeerInactivityExpiration.Valid { - account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) - } - if sRegularUsersViewBlocked.Valid { - account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool - } - if sGroupsPropagationEnabled.Valid { - account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool - } - if sJWTGroupsEnabled.Valid { - account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool - } - if sJWTGroupsClaimName.Valid { - account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String - } - if sRoutingPeerDNSResolutionEnabled.Valid { - 
account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool - } - if sDNSDomain.Valid { - account.Settings.DNSDomain = sDNSDomain.String - } - if sLazyConnectionEnabled.Valid { - account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool - } - if sJWTAllowGroups != nil { - _ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) - } - if sNetworkRange != nil { - _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) - } - - if sExtraPeerApprovalEnabled.Valid { - account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool - } - if sExtraUserApprovalRequired.Valid { - account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool - } - if sExtraIntegratedValidator.Valid { - account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String - } - if sExtraIntegratedValidatorGroups != nil { - _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) - } - - var wg sync.WaitGroup - errChan := make(chan error, 12) - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { - var sk types.SetupKey - var autoGroups []byte - var expiresAt, updatedAt, lastUsed sql.NullTime - var revoked, ephemeral, allowExtraDNSLabels sql.NullBool - var usedTimes, usageLimit sql.NullInt64 - - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) - - if err == nil { - if 
expiresAt.Valid { - sk.ExpiresAt = &expiresAt.Time - } - if updatedAt.Valid { - sk.UpdatedAt = updatedAt.Time - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt - } - } - if lastUsed.Valid { - sk.LastUsed = &lastUsed.Time - } - if revoked.Valid { - sk.Revoked = revoked.Bool - } - if usedTimes.Valid { - sk.UsedTimes = int(usedTimes.Int64) - } - if usageLimit.Valid { - sk.UsageLimit = int(usageLimit.Int64) - } - if ephemeral.Valid { - sk.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} - } - } - return sk, err - }) - if err != nil { - errChan <- err - return - } - account.SetupKeysG = keys - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { - var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} - var lastLogin sql.NullTime - var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool - var 
peerStatusLastSeen sql.NullTime - var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool - var ip, extraDNS, netAddr, env, flags, files, connIP []byte - - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) - - if err == nil { - if lastLogin.Valid { - p.LastLogin = &lastLogin.Time - } - if sshEnabled.Valid { - p.SSHEnabled = sshEnabled.Bool - } - if loginExpirationEnabled.Valid { - p.LoginExpirationEnabled = loginExpirationEnabled.Bool - } - if inactivityExpirationEnabled.Valid { - p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool - } - if ephemeral.Valid { - p.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if peerStatusLastSeen.Valid { - p.Status.LastSeen = peerStatusLastSeen.Time - } - if peerStatusConnected.Valid { - p.Status.Connected = peerStatusConnected.Bool - } - if peerStatusLoginExpired.Valid { - p.Status.LoginExpired = peerStatusLoginExpired.Bool - } - if peerStatusRequiresApproval.Valid { - p.Status.RequiresApproval = peerStatusRequiresApproval.Bool - } - - if ip != nil { - _ = json.Unmarshal(ip, &p.IP) - } - if extraDNS != nil { - _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) - } - if netAddr != nil { - _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) - } - if env != nil { - _ = 
json.Unmarshal(env, &p.Meta.Environment) - } - if flags != nil { - _ = json.Unmarshal(flags, &p.Meta.Flags) - } - if files != nil { - _ = json.Unmarshal(files, &p.Meta.Files) - } - if connIP != nil { - _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) - } - } - return p, err - }) - if err != nil { - errChan <- err - return - } - account.PeersG = peers - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { - var u types.User - var autoGroups []byte - var lastLogin sql.NullTime - var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil { - if lastLogin.Valid { - u.LastLogin = &lastLogin.Time - } - if isServiceUser.Valid { - u.IsServiceUser = isServiceUser.Bool - } - if nonDeletable.Valid { - u.NonDeletable = nonDeletable.Bool - } - if blocked.Valid { - u.Blocked = blocked.Bool - } - if pendingApproval.Valid { - u.PendingApproval = pendingApproval.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} - } - } - return u, err - }) - if err != nil { - errChan <- err - return - } - account.UsersG = users - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id 
= $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { - var g types.Group - var resources []byte - var refID sql.NullInt64 - var refType sql.NullString - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) - if err == nil { - if refID.Valid { - g.IntegrationReference.ID = int(refID.Int64) - } - if refType.Valid { - g.IntegrationReference.IntegrationType = refType.String - } - if resources != nil { - _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} - } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} - } - return &g, err - }) - if err != nil { - errChan <- err - return - } - account.GroupsG = groups - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { - var p types.Policy - var checks []byte - var enabled sql.NullBool - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) - if err == nil { - if enabled.Valid { - p.Enabled = enabled.Bool - } - if checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) - } - } - return &p, err - }) - if err != nil { - errChan <- err - return - } - account.Policies = policies - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - 
routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { - var r route.Route - var network, domains, peerGroups, groups, accessGroups []byte - var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) - if err == nil { - if keepRoute.Valid { - r.KeepRoute = keepRoute.Bool - } - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if skipAutoApply.Valid { - r.SkipAutoApply = skipAutoApply.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if network != nil { - _ = json.Unmarshal(network, &r.Network) - } - if domains != nil { - _ = json.Unmarshal(domains, &r.Domains) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - if groups != nil { - _ = json.Unmarshal(groups, &r.Groups) - } - if accessGroups != nil { - _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.RoutesG = routes - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { - var n nbdns.NameServerGroup - var ns, groups, domains []byte - var primary, enabled, searchDomainsEnabled sql.NullBool - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) - if err == nil { - if primary.Valid { - n.Primary = 
primary.Bool - } - if enabled.Valid { - n.Enabled = enabled.Bool - } - if searchDomainsEnabled.Valid { - n.SearchDomainsEnabled = searchDomainsEnabled.Bool - } - if ns != nil { - _ = json.Unmarshal(ns, &n.NameServers) - } else { - n.NameServers = []nbdns.NameServer{} - } - if groups != nil { - _ = json.Unmarshal(groups, &n.Groups) - } else { - n.Groups = []string{} - } - if domains != nil { - _ = json.Unmarshal(domains, &n.Domains) - } else { - n.Domains = []string{} - } - } - return n, err - }) - if err != nil { - errChan <- err - return - } - account.NameServerGroupsG = nsgs - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { - var c posture.Checks - var checksDef []byte - err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) - if err == nil && checksDef != nil { - _ = json.Unmarshal(checksDef, &c.Checks) - } - return &c, err - }) - if err != nil { - errChan <- err - return - } - account.PostureChecks = checks - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) - if err != nil { - errChan <- err - return - } - account.Networks = make([]*networkTypes.Network, len(networks)) - for i := range networks { - account.Networks[i] = &networks[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, 
accountID) - if err != nil { - errChan <- err - return - } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { - var r routerTypes.NetworkRouter - var peerGroups []byte - var masquerade, enabled sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) - if err == nil { - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) - for i := range routers { - account.NetworkRouters[i] = &routers[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { - var r resourceTypes.NetworkResource - var prefix []byte - var enabled sql.NullBool - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources)) - for i := range resources { - account.NetworkResources[i] = &resources[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, 
created_at, updated_at FROM account_onboardings WHERE account_id = $1` - var onboardingFlowPending, signupFormPending sql.NullBool - err := s.pool.QueryRow(ctx, query, accountID).Scan( - &account.Onboarding.AccountID, - &onboardingFlowPending, - &signupFormPending, - &account.Onboarding.CreatedAt, - &account.Onboarding.UpdatedAt, - ) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - errChan <- err - return - } - if onboardingFlowPending.Valid { - account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool - } - if signupFormPending.Valid { - account.Onboarding.SignupFormPending = signupFormPending.Bool - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - var userIDs []string - for _, u := range account.UsersG { - userIDs = append(userIDs, u.Id) - } - var policyIDs []string - for _, p := range account.Policies { - policyIDs = append(policyIDs, p.ID) - } - var groupIDs []string - for _, g := range account.GroupsG { - groupIDs = append(groupIDs, g.ID) - } - - wg.Add(3) - errChan = make(chan error, 3) - - var pats []types.PersonalAccessToken - go func() { - defer wg.Done() - if len(userIDs) == 0 { - return - } - const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, userIDs) - if err != nil { - errChan <- err - return - } - pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { - var pat types.PersonalAccessToken - var expirationDate, lastUsed sql.NullTime - err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) - if err == nil { - if expirationDate.Valid { - pat.ExpirationDate = &expirationDate.Time - } - if lastUsed.Valid { - pat.LastUsed = &lastUsed.Time - } - } - return pat, err - }) - if err != nil { - errChan <- err - } - }() - - var 
rules []*types.PolicyRule - go func() { - defer wg.Done() - if len(policyIDs) == 0 { - return - } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, policyIDs) - if err != nil { - errChan <- err - return - } - rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { - var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte - var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if bidirectional.Valid { - r.Bidirectional = bidirectional.Bool - } - if dest != nil { - _ = json.Unmarshal(dest, &r.Destinations) - } - if destRes != nil { - _ = json.Unmarshal(destRes, &r.DestinationResource) - } - if sources != nil { - _ = json.Unmarshal(sources, &r.Sources) - } - if sourceRes != nil { - _ = json.Unmarshal(sourceRes, &r.SourceResource) - } - if ports != nil { - _ = json.Unmarshal(ports, &r.Ports) - } - if portRanges != nil { - _ = json.Unmarshal(portRanges, &r.PortRanges) - } - } - return &r, err - }) - if err != nil { - errChan <- err - } - }() - - var groupPeers []types.GroupPeer - go func() { - defer wg.Done() - if len(groupIDs) == 0 { - return - } - const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, groupIDs) - if err != nil { - errChan <- err - return - } - groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) - if err != nil { - errChan <- err - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - 
patsByUserID := make(map[string][]*types.PersonalAccessToken) - for i := range pats { - pat := &pats[i] - patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) - pat.UserID = "" - } - - rulesByPolicyID := make(map[string][]*types.PolicyRule) - for _, rule := range rules { - rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) - } - - peersByGroupID := make(map[string][]string) - for _, gp := range groupPeers { - peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) - } - - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) - for i := range account.SetupKeysG { - key := &account.SetupKeysG[i] - account.SetupKeys[key.Key] = key - } - - account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) - for i := range account.PeersG { - peer := &account.PeersG[i] - account.Peers[peer.ID] = peer - } - - account.Users = make(map[string]*types.User, len(account.UsersG)) - for i := range account.UsersG { - user := &account.UsersG[i] - user.PATs = make(map[string]*types.PersonalAccessToken) - if userPats, ok := patsByUserID[user.Id]; ok { - for j := range userPats { - pat := userPats[j] - user.PATs[pat.ID] = pat - } - } - account.Users[user.Id] = user - } - - for i := range account.Policies { - policy := account.Policies[i] - if policyRules, ok := rulesByPolicyID[policy.ID]; ok { - policy.Rules = policyRules - } - } - - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) - for i := range account.GroupsG { - group := account.GroupsG[i] - if peerIDs, ok := peersByGroupID[group.ID]; ok { - group.Peers = peerIDs - } - account.Groups[group.ID] = group - } - - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for i := range account.RoutesG { - route := &account.RoutesG[i] - account.Routes[route.ID] = route - } - - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for i := range account.NameServerGroupsG { 
- nsg := &account.NameServerGroupsG[i] - nsg.AccountID = "" - account.NameServerGroups[nsg.ID] = nsg - } - - account.SetupKeysG = nil - account.PeersG = nil - account.UsersG = nil - account.GroupsG = nil - account.RoutesG = nil - account.NameServerGroupsG = nil - - return &account, nil -} - -func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -1634,8 +786,7 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. var account types.Account result := s.db.Model(&account). - // Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload("UsersG.PATsG"). Preload("Policies.Rules"). Preload("SetupKeysG"). Preload("PeersG"). @@ -1657,21 +808,14 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - // for i, policy := range account.Policies { - // var rules []*types.PolicyRule - // err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - // if err != nil { - // return nil, status.Errorf(status.NotFound, "rule not found") - // } - // account.Policies[i].Rules = rules - // } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { if key.UpdatedAt.IsZero() { key.UpdatedAt = key.CreatedAt } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil @@ -1689,6 +833,9 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. 
pat.UserID = "" user.PATs[pat.ID] = &pat } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } account.Users[user.Id] = &user user.PATsG = nil } @@ -1700,21 +847,13 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. for i, gp := range group.GroupPeers { group.Peers[i] = gp.PeerID } + if group.Resources == nil { + group.Resources = []types.Resource{} + } account.Groups[group.ID] = group } account.GroupsG = nil - // var groupPeers []types.GroupPeer - // s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - // Find(&groupPeers) - // for _, groupPeer := range groupPeers { - // if group, ok := account.Groups[groupPeer.GroupID]; ok { - // group.Peers = append(group.Peers, groupPeer.PeerID) - // } else { - // log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) - // } - // } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) for _, route := range account.RoutesG { account.Routes[route.ID] = &route @@ -1724,6 +863,15 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. 
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) for _, ns := range account.NameServerGroupsG { ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } account.NameServerGroups[ns.ID] = &ns } account.NameServerGroupsG = nil @@ -2070,7 +1218,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } - pool, err := connectDB(context.Background(), dsn) + pool, err := connectToPgDb(context.Background(), dsn) if err != nil { return nil, err } @@ -2083,7 +1231,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe return store, nil } -func connectDB(ctx context.Context, dsn string) (*pgxpool.Pool, error) { +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("unable to parse database config: %w", err) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 2b0256bd8..8c94f53c4 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -30,7 +30,6 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) @@ -153,7 +152,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, string) { b.Fatalf("failed to connect database: %v", err) } - pool, err := connectDB(context.Background(), dsn) + pool, err := connectDBforTest(context.Background(), dsn) if err != nil { b.Fatalf("failed to connect database: %v", err) } @@ -434,7 +433,7 @@ func TestAccountEquivalence(t 
*testing.T) { expectedF getAccountFunc actualF getAccountFunc }{ - // {"old vs new", store.GetAccountSlow, store.GetAccount}, + {"old vs new", store.GetAccountSlow, store.GetAccount}, {"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL}, } @@ -461,9 +460,6 @@ func TestAccountEquivalence(t *testing.T) { } func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { - // normalizeNilSlices(expected) - // normalizeNilSlices(actual) - assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") @@ -566,136 +562,290 @@ func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { } } -func normalizeNilSlices(acc *types.Account) { - if acc == nil { - return +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err } - if acc.Policies == nil { - acc.Policies = []*types.Policy{} - } - if acc.PostureChecks == nil { - acc.PostureChecks = []*posture.Checks{} - } - if acc.Networks == nil { - acc.Networks = []*networkTypes.Network{} - } - if acc.NetworkRouters == nil { - acc.NetworkRouters = []*routerTypes.NetworkRouter{} - } - if acc.NetworkResources == nil { - acc.NetworkResources = []*resourceTypes.NetworkResource{} - } - if acc.DNSSettings.DisabledManagementGroups == nil { - acc.DNSSettings.DisabledManagementGroups = []string{} - } + var wg sync.WaitGroup + errChan := make(chan error, 12) - for _, key := range acc.SetupKeys { - if key.AutoGroups == nil { - key.AutoGroups = []string{} + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go 
func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e 
!= nil { + return nil, e } } - for _, peer := range acc.Peers { - if peer.ExtraDNSLabels == nil { - peer.ExtraDNSLabels = []string{} + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - for _, user := range acc.Users { - if user.AutoGroups == nil { - user.AutoGroups = []string{} - } + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" } - for _, group := range acc.Groups { - if group.Peers == nil { - group.Peers = []string{} - } - if group.Resources == nil { - group.Resources = []types.Resource{} - } - if group.GroupPeers == nil { - group.GroupPeers = []types.GroupPeer{} - } + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) } - for _, route := range acc.Routes { - if route.Domains == nil { - route.Domains = domain.List{} - } - if route.PeerGroups == nil { - route.PeerGroups = []string{} - } - 
if route.Groups == nil { - route.Groups = []string{} - } - if route.AccessControlGroups == nil { - route.AccessControlGroups = []string{} - } + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) } - for _, nsg := range acc.NameServerGroups { - if nsg.NameServers == nil { - nsg.NameServers = []nbdns.NameServer{} - } - if nsg.Groups == nil { - nsg.Groups = []string{} - } - if nsg.Domains == nil { - nsg.Domains = []string{} - } + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key } - for _, policy := range acc.Policies { - if policy.SourcePostureChecks == nil { - policy.SourcePostureChecks = []string{} - } - if policy.Rules == nil { - policy.Rules = []*types.PolicyRule{} - } - for _, rule := range policy.Rules { - if rule.Destinations == nil { - rule.Destinations = []string{} - } - if rule.Sources == nil { - rule.Sources = []string{} - } - if rule.Ports == nil { - rule.Ports = []string{} - } - if rule.PortRanges == nil { - rule.PortRanges = []types.RulePortRange{} + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat } } + account.Users[user.Id] = user } - for _, check := range acc.PostureChecks { - if check.Checks.GeoLocationCheck != nil { - if check.Checks.GeoLocationCheck.Locations == nil { - check.Checks.GeoLocationCheck.Locations = []posture.Location{} - } - } - if check.Checks.PeerNetworkRangeCheck != nil { - 
if check.Checks.PeerNetworkRangeCheck.Ranges == nil { - check.Checks.PeerNetworkRangeCheck.Ranges = []netip.Prefix{} - } - } - if check.Checks.ProcessCheck != nil { - if check.Checks.ProcessCheck.Processes == nil { - check.Checks.ProcessCheck.Processes = []posture.Process{} - } + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules } } - for _, router := range acc.NetworkRouters { - if router.PeerGroups == nil { - router.PeerGroups = []string{} + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs } + account.Groups[group.ID] = group } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil } -func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { var account types.Account account.Network = &types.Network{} const accountQuery = ` @@ -814,729 +964,566 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty if sExtraIntegratedValidatorGroups != nil { _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) } - - var wg sync.WaitGroup - errChan := make(chan 
error, 12) - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { - var sk types.SetupKey - var autoGroups []byte - var expiresAt, updatedAt, lastUsed sql.NullTime - var revoked, ephemeral, allowExtraDNSLabels sql.NullBool - var usedTimes, usageLimit sql.NullInt64 - - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) - - if err == nil { - if expiresAt.Valid { - sk.ExpiresAt = &expiresAt.Time - } - if updatedAt.Valid { - sk.UpdatedAt = updatedAt.Time - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt - } - } - if lastUsed.Valid { - sk.LastUsed = &lastUsed.Time - } - if revoked.Valid { - sk.Revoked = revoked.Bool - } - if usedTimes.Valid { - sk.UsedTimes = int(usedTimes.Int64) - } - if usageLimit.Valid { - sk.UsageLimit = int(usageLimit.Int64) - } - if ephemeral.Valid { - sk.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} - } - } - return sk, err - }) - if err != nil { - errChan <- err - return - } - account.SetupKeysG = keys - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, 
meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { - var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} - var lastLogin sql.NullTime - var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool - var peerStatusLastSeen sql.NullTime - var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool - var ip, extraDNS, netAddr, env, flags, files, connIP []byte - - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) - - if err == nil { - if lastLogin.Valid { - p.LastLogin = &lastLogin.Time - } - if sshEnabled.Valid { - p.SSHEnabled = sshEnabled.Bool - } - if loginExpirationEnabled.Valid { - 
p.LoginExpirationEnabled = loginExpirationEnabled.Bool - } - if inactivityExpirationEnabled.Valid { - p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool - } - if ephemeral.Valid { - p.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if peerStatusLastSeen.Valid { - p.Status.LastSeen = peerStatusLastSeen.Time - } - if peerStatusConnected.Valid { - p.Status.Connected = peerStatusConnected.Bool - } - if peerStatusLoginExpired.Valid { - p.Status.LoginExpired = peerStatusLoginExpired.Bool - } - if peerStatusRequiresApproval.Valid { - p.Status.RequiresApproval = peerStatusRequiresApproval.Bool - } - - if ip != nil { - _ = json.Unmarshal(ip, &p.IP) - } - if extraDNS != nil { - _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) - } - if netAddr != nil { - _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) - } - if env != nil { - _ = json.Unmarshal(env, &p.Meta.Environment) - } - if flags != nil { - _ = json.Unmarshal(flags, &p.Meta.Flags) - } - if files != nil { - _ = json.Unmarshal(files, &p.Meta.Files) - } - if connIP != nil { - _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) - } - } - return p, err - }) - if err != nil { - errChan <- err - return - } - account.PeersG = peers - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { - var u types.User - var autoGroups []byte - var lastLogin sql.NullTime - var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, 
&u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil { - if lastLogin.Valid { - u.LastLogin = &lastLogin.Time - } - if isServiceUser.Valid { - u.IsServiceUser = isServiceUser.Bool - } - if nonDeletable.Valid { - u.NonDeletable = nonDeletable.Bool - } - if blocked.Valid { - u.Blocked = blocked.Bool - } - if pendingApproval.Valid { - u.PendingApproval = pendingApproval.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} - } - } - return u, err - }) - if err != nil { - errChan <- err - return - } - account.UsersG = users - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { - var g types.Group - var resources []byte - var refID sql.NullInt64 - var refType sql.NullString - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) - if err == nil { - if refID.Valid { - g.IntegrationReference.ID = int(refID.Int64) - } - if refType.Valid { - g.IntegrationReference.IntegrationType = refType.String - } - if resources != nil { - _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} - } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} - } - return &g, err - }) - if err != nil { - errChan <- err - return - } - account.GroupsG = groups - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != 
nil { - errChan <- err - return - } - policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { - var p types.Policy - var checks []byte - var enabled sql.NullBool - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) - if err == nil { - if enabled.Valid { - p.Enabled = enabled.Bool - } - if checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) - } - } - return &p, err - }) - if err != nil { - errChan <- err - return - } - account.Policies = policies - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { - var r route.Route - var network, domains, peerGroups, groups, accessGroups []byte - var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) - if err == nil { - if keepRoute.Valid { - r.KeepRoute = keepRoute.Bool - } - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if skipAutoApply.Valid { - r.SkipAutoApply = skipAutoApply.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if network != nil { - _ = json.Unmarshal(network, &r.Network) - } - if domains != nil { - _ = json.Unmarshal(domains, &r.Domains) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - if groups != nil { - _ = json.Unmarshal(groups, &r.Groups) - } - if 
accessGroups != nil { - _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.RoutesG = routes - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { - var n nbdns.NameServerGroup - var ns, groups, domains []byte - var primary, enabled, searchDomainsEnabled sql.NullBool - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) - if err == nil { - if primary.Valid { - n.Primary = primary.Bool - } - if enabled.Valid { - n.Enabled = enabled.Bool - } - if searchDomainsEnabled.Valid { - n.SearchDomainsEnabled = searchDomainsEnabled.Bool - } - if ns != nil { - _ = json.Unmarshal(ns, &n.NameServers) - } else { - n.NameServers = []nbdns.NameServer{} - } - if groups != nil { - _ = json.Unmarshal(groups, &n.Groups) - } else { - n.Groups = []string{} - } - if domains != nil { - _ = json.Unmarshal(domains, &n.Domains) - } else { - n.Domains = []string{} - } - } - return n, err - }) - if err != nil { - errChan <- err - return - } - account.NameServerGroupsG = nsgs - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { - var c posture.Checks - var checksDef []byte - err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) - if err == nil && 
checksDef != nil { - _ = json.Unmarshal(checksDef, &c.Checks) - } - return &c, err - }) - if err != nil { - errChan <- err - return - } - account.PostureChecks = checks - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) - if err != nil { - errChan <- err - return - } - account.Networks = make([]*networkTypes.Network, len(networks)) - for i := range networks { - account.Networks[i] = &networks[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { - var r routerTypes.NetworkRouter - var peerGroups []byte - var masquerade, enabled sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) - if err == nil { - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) - for i := range routers { - account.NetworkRouters[i] = &routers[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` 
- rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { - var r resourceTypes.NetworkResource - var prefix []byte - var enabled sql.NullBool - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources)) - for i := range resources { - account.NetworkResources[i] = &resources[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` - var onboardingFlowPending, signupFormPending sql.NullBool - err := s.pool.QueryRow(ctx, query, accountID).Scan( - &account.Onboarding.AccountID, - &onboardingFlowPending, - &signupFormPending, - &account.Onboarding.CreatedAt, - &account.Onboarding.UpdatedAt, - ) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - errChan <- err - return - } - if onboardingFlowPending.Valid { - account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool - } - if signupFormPending.Valid { - account.Onboarding.SignupFormPending = signupFormPending.Bool - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - var userIDs []string - for _, u := range account.UsersG { - userIDs = append(userIDs, u.Id) - } - var policyIDs []string - for _, p := range account.Policies { - policyIDs = append(policyIDs, p.ID) - } - var groupIDs []string - for _, g := range account.GroupsG { - groupIDs = append(groupIDs, g.ID) - } - - wg.Add(3) - errChan = make(chan 
error, 3) - - var pats []types.PersonalAccessToken - go func() { - defer wg.Done() - if len(userIDs) == 0 { - return - } - const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, userIDs) - if err != nil { - errChan <- err - return - } - pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { - var pat types.PersonalAccessToken - var expirationDate, lastUsed sql.NullTime - err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) - if err == nil { - if expirationDate.Valid { - pat.ExpirationDate = &expirationDate.Time - } - if lastUsed.Valid { - pat.LastUsed = &lastUsed.Time - } - } - return pat, err - }) - if err != nil { - errChan <- err - } - }() - - var rules []*types.PolicyRule - go func() { - defer wg.Done() - if len(policyIDs) == 0 { - return - } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, policyIDs) - if err != nil { - errChan <- err - return - } - rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { - var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte - var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if bidirectional.Valid { - r.Bidirectional = bidirectional.Bool - } - if dest != nil { - _ = json.Unmarshal(dest, &r.Destinations) - } - if destRes != nil { - _ = 
json.Unmarshal(destRes, &r.DestinationResource) - } - if sources != nil { - _ = json.Unmarshal(sources, &r.Sources) - } - if sourceRes != nil { - _ = json.Unmarshal(sourceRes, &r.SourceResource) - } - if ports != nil { - _ = json.Unmarshal(ports, &r.Ports) - } - if portRanges != nil { - _ = json.Unmarshal(portRanges, &r.PortRanges) - } - } - return &r, err - }) - if err != nil { - errChan <- err - } - }() - - var groupPeers []types.GroupPeer - go func() { - defer wg.Done() - if len(groupIDs) == 0 { - return - } - const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, groupIDs) - if err != nil { - errChan <- err - return - } - groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) - if err != nil { - errChan <- err - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - patsByUserID := make(map[string][]*types.PersonalAccessToken) - for i := range pats { - pat := &pats[i] - patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) - pat.UserID = "" - } - - rulesByPolicyID := make(map[string][]*types.PolicyRule) - for _, rule := range rules { - rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) - } - - peersByGroupID := make(map[string][]string) - for _, gp := range groupPeers { - peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) - } - - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) - for i := range account.SetupKeysG { - key := &account.SetupKeysG[i] - account.SetupKeys[key.Key] = key - } - - account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) - for i := range account.PeersG { - peer := &account.PeersG[i] - account.Peers[peer.ID] = peer - } - - account.Users = make(map[string]*types.User, len(account.UsersG)) - for i := range account.UsersG { - user := &account.UsersG[i] - user.PATs = 
make(map[string]*types.PersonalAccessToken) - if userPats, ok := patsByUserID[user.Id]; ok { - for j := range userPats { - pat := userPats[j] - user.PATs[pat.ID] = pat - } - } - account.Users[user.Id] = user - } - - for i := range account.Policies { - policy := account.Policies[i] - if policyRules, ok := rulesByPolicyID[policy.ID]; ok { - policy.Rules = policyRules - } - } - - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) - for i := range account.GroupsG { - group := account.GroupsG[i] - if peerIDs, ok := peersByGroupID[group.ID]; ok { - group.Peers = peerIDs - } - account.Groups[group.ID] = group - } - - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for i := range account.RoutesG { - route := &account.RoutesG[i] - account.Routes[route.ID] = route - } - - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for i := range account.NameServerGroupsG { - nsg := &account.NameServerGroupsG[i] - nsg.AccountID = "" - account.NameServerGroups[nsg.ID] = nsg - } - - account.SetupKeysG = nil - account.PeersG = nil - account.UsersG = nil - account.GroupsG = nil - account.RoutesG = nil - account.NameServerGroupsG = nil - return &account, nil } + +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, 
&sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) 
(nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + var ip, extraDNS, netAddr, env, flags, files, connIP []byte + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + + if ip != nil { + _ = 
json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err 
+ } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const 
query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row 
pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + 
if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, 
&prefix, &enabled)
+		if err == nil {
+			if enabled.Valid {
+				r.Enabled = enabled.Bool
+			}
+			if prefix != nil {
+				_ = json.Unmarshal(prefix, &r.Prefix)
+			}
+		}
+		return r, err
+	})
+	if err != nil {
+		return nil, err
+	}
+	// CollectRows produced a slice of values; convert to the []*NetworkResource
+	// shape callers expect. Pointers into the backing slice are stable here
+	// because the slice is not appended to afterwards.
+	result := make([]*resourceTypes.NetworkResource, len(resources))
+	for i := range resources {
+		result[i] = &resources[i]
+	}
+	return result, nil
+}
+
+// getAccountOnboarding loads the account_onboardings row for accountID into
+// account.Onboarding. A missing row is deliberately not an error: pgx.ErrNoRows
+// is swallowed and the zero values already on account.Onboarding are kept.
+func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error {
+	const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
+	var onboardingFlowPending, signupFormPending sql.NullBool
+	err := s.pool.QueryRow(ctx, query, accountID).Scan(
+		&account.Onboarding.AccountID,
+		&onboardingFlowPending,
+		&signupFormPending,
+		&account.Onboarding.CreatedAt,
+		&account.Onboarding.UpdatedAt,
+	)
+	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
+		return err
+	}
+	// NULL columns leave the corresponding bool at its zero value (false).
+	if onboardingFlowPending.Valid {
+		account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool
+	}
+	if signupFormPending.Valid {
+		account.Onboarding.SignupFormPending = signupFormPending.Bool
+	}
+	return nil
+}
+
+// getPersonalAccessTokens returns every personal access token owned by any of
+// userIDs. An empty userIDs slice short-circuits to (nil, nil) without a query.
+// Nullable timestamp columns are mapped to nil pointers when absent.
+func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) {
+	if len(userIDs) == 0 {
+		return nil, nil
+	}
+	const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
+	rows, err := s.pool.Query(ctx, query, userIDs)
+	if err != nil {
+		return nil, err
+	}
+	// CollectRows closes rows for us, including on scan error.
+	pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) {
+		var pat types.PersonalAccessToken
+		var expirationDate, lastUsed sql.NullTime
+		err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed)
+		if err == nil {
+			if expirationDate.Valid {
+				pat.ExpirationDate = &expirationDate.Time
+			}
+			if lastUsed.Valid {
+				pat.LastUsed = &lastUsed.Time
+			}
+		}
+		return pat, err
+	})
+	if err != nil {
+		return nil, err
+	}
+	return pats, nil
+}
+
+// getPolicyRules returns all rules attached to the given policies. List-like
+// columns (destinations, sources, ports, port ranges, ...) are stored as JSON
+// and decoded here; decode errors are discarded (`_ =`), leaving the target
+// field zero-valued. An empty policyIDs slice returns (nil, nil) without a query.
+func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) {
+	if len(policyIDs) == 0 {
+		return nil, nil
+	}
+	const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
+	rows, err := s.pool.Query(ctx, query, policyIDs)
+	if err != nil {
+		return nil, err
+	}
+	rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
+		var r types.PolicyRule
+		var dest, destRes, sources, sourceRes, ports, portRanges []byte
+		var enabled, bidirectional sql.NullBool
+		err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
+		if err == nil {
+			// NULL bools keep the struct's zero value (false).
+			if enabled.Valid {
+				r.Enabled = enabled.Bool
+			}
+			if bidirectional.Valid {
+				r.Bidirectional = bidirectional.Bool
+			}
+			// nil checks skip NULL columns; non-NULL columns are best-effort decoded.
+			if dest != nil {
+				_ = json.Unmarshal(dest, &r.Destinations)
+			}
+			if destRes != nil {
+				_ = json.Unmarshal(destRes, &r.DestinationResource)
+			}
+			if sources != nil {
+				_ = json.Unmarshal(sources, &r.Sources)
+			}
+			if sourceRes != nil {
+				_ = json.Unmarshal(sourceRes, &r.SourceResource)
+			}
+			if ports != nil {
+				_ = json.Unmarshal(ports, &r.Ports)
+			}
+			if portRanges != nil {
+				_ = json.Unmarshal(portRanges, &r.PortRanges)
+			}
+		}
+		return &r, err
+	})
+	if err != nil {
+		return nil, err
+	}
+	return rules, nil
+}
+
+// getGroupPeers returns the group membership rows (account_id, group_id,
+// peer_id) for the given group IDs; an empty input returns (nil, nil) without
+// querying. Rows are mapped by column name via pgx.RowToStructByName.
+func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) {
+	if len(groupIDs) == 0 {
+		return nil, nil
+	}
+	const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
+	rows, err := s.pool.Query(ctx, query, groupIDs)
+	if err != nil {
+		return nil, err
+	}
+	groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
+	if err != nil {
+		return nil, err
+	}
+	return groupPeers, nil
+}

From cdd2c97a461887c233937f9a440729752e2660c3 Mon Sep 17 00:00:00 2001
From: Pascal Fischer
Date: Tue, 21 Oct 2025 13:46:10 +0200
Subject: [PATCH 11/14] get account test

---
 .../store/sql_store_get_account_test.go       | 1089 +++++++++++++++++
 1 file changed, 1089 insertions(+)
 create mode 100644 management/server/store/sql_store_get_account_test.go

diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go
new file mode 100644
index 000000000..8ff04d68a
--- /dev/null
+++ b/management/server/store/sql_store_get_account_test.go
@@ -0,0 +1,1089 @@
+package store
+
+import (
+	"context"
+	"net"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	nbdns "github.com/netbirdio/netbird/dns"
+	"github.com/netbirdio/netbird/management/server/integration_reference"
+	resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
+	routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+	networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
+	nbpeer "github.com/netbirdio/netbird/management/server/peer"
+	"github.com/netbirdio/netbird/management/server/posture"
+	"github.com/netbirdio/netbird/management/server/types"
+	"github.com/netbirdio/netbird/route"
+)
+
+// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
+// all fields and nested objects from the database, including deeply nested structures.
+func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: 
true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: 
now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: 
patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: 
[]types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, 
policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + 
postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create 
Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t 
*testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 
2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== + t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") 
+ assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not 
be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a 
service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, 
"User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] + require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID 
mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports 
length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, 
r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), 
nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 
1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, 
pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, 
networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} From 4896428d76338271aab760664315662913913a39 Mon Sep 17 00:00:00 2001 From: crn4 Date: Tue, 21 Oct 2025 14:07:27 +0200 Subject: [PATCH 12/14] changed GetAccount to sql raw version --- management/server/store/sql_store.go | 1008 +++++++++++++++-- .../server/store/sqlstore_bench_test.go | 688 +---------- 2 files changed, 937 insertions(+), 759 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 41d2af0c2..b5f4d8cbf 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,7 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" @@ -776,109 +778,971 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. 
} func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { - start := time.Now() - defer func() { - elapsed := time.Since(start) - if elapsed > 1*time.Second { - log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + 
account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return } }() - var account types.Account - result := s.db.Model(&account). - Preload("UsersG.PATsG"). - Preload("Policies.Rules"). - Preload("SetupKeysG"). - Preload("PeersG"). - Preload("UsersG"). - Preload("GroupsG.GroupPeers"). - Preload("RoutesG"). - Preload("NameServerGroupsG"). - Preload("PostureChecks"). - Preload("Networks"). - Preload("NetworkRouters"). - Preload("NetworkResources"). - Preload("Onboarding"). - Take(&account, idQueryCondition, accountID) - if result.Error != nil { - log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.NewAccountNotFoundError(accountID) + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } - return nil, status.NewGetAccountFromStoreError(result.Error) + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules 
[]*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) } account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) - for _, key := range account.SetupKeysG { - if key.UpdatedAt.IsZero() { - key.UpdatedAt = key.CreatedAt - } - if key.AutoGroups == nil { - key.AutoGroups = []string{} - } - account.SetupKeys[key.Key] = &key + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key } - account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) - for _, peer := range account.PeersG { - account.Peers[peer.ID] = &peer + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer } - account.PeersG = nil account.Users = make(map[string]*types.User, len(account.UsersG)) - for _, user := range account.UsersG { - user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) - for _, pat := range user.PATsG { - pat.UserID = "" - user.PATs[pat.ID] = &pat + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = 
make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules } - if user.AutoGroups == nil { - user.AutoGroups = []string{} - } - account.Users[user.Id] = &user - user.PATsG = nil } - account.UsersG = nil account.Groups = make(map[string]*types.Group, len(account.GroupsG)) - for _, group := range account.GroupsG { - group.Peers = make([]string, len(group.GroupPeers)) - for i, gp := range group.GroupPeers { - group.Peers[i] = gp.PeerID - } - if group.Resources == nil { - group.Resources = []types.Resource{} + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs } account.Groups[group.ID] = group } - account.GroupsG = nil account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = &route + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - ns.AccountID = "" - if ns.NameServers == nil { - ns.NameServers = []nbdns.NameServer{} - } - if ns.Groups == nil { - ns.Groups = []string{} - } - if ns.Domains == nil { - ns.Domains = []string{} - } - account.NameServerGroups[ns.ID] = &ns + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil + return 
account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var networkNet, dnsSettingsDisabledGroups []byte + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups []byte + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange []byte + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups []byte + ) + + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, 
&account.DomainCategory, &account.IsDomainPrimaryAccount, + &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, err + } + + _ = json.Unmarshal(networkNet, &account.Network.Net) + _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + 
account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups != nil { + _ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) + } + if sNetworkRange != nil { + _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups != nil { + _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) + } return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if 
expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool 
+ var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + var ip, extraDNS, netAddr, env, flags, files, connIP []byte + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = 
json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, 
integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE 
account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, 
&n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) 
getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + 
return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &account.Onboarding.CreatedAt, + &account.Onboarding.UpdatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) 
getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return 
groupPeers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 8c94f53c4..b131c69ac 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -2,8 +2,6 @@ package store import ( "context" - "database/sql" - "encoding/json" "errors" "fmt" "net" @@ -17,7 +15,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -126,7 +123,7 @@ func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { return nil, fmt.Errorf("unable to parse database config: %w", err) } - config.MaxConns = 10 + config.MaxConns = 12 config.MinConns = 2 config.MaxConnLifetime = time.Hour config.HealthCheckPeriod = time.Minute @@ -844,686 +841,3 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty return account, nil } - -func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { - var account types.Account - account.Network = &types.Network{} - const accountQuery = ` - SELECT - id, created_by, created_at, domain, domain_category, is_domain_primary_account, - -- Embedded Network - network_identifier, network_net, network_dns, network_serial, - -- Embedded DNSSettings - dns_settings_disabled_management_groups, - -- Embedded Settings - settings_peer_login_expiration_enabled, settings_peer_login_expiration, - settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, - settings_regular_users_view_blocked, settings_groups_propagation_enabled, - settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, - 
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, - settings_lazy_connection_enabled, - -- Embedded ExtraSettings - settings_extra_peer_approval_enabled, settings_extra_user_approval_required, - settings_extra_integrated_validator, settings_extra_integrated_validator_groups - FROM accounts WHERE id = $1` - - var networkNet, dnsSettingsDisabledGroups []byte - var ( - sPeerLoginExpirationEnabled sql.NullBool - sPeerLoginExpiration sql.NullInt64 - sPeerInactivityExpirationEnabled sql.NullBool - sPeerInactivityExpiration sql.NullInt64 - sRegularUsersViewBlocked sql.NullBool - sGroupsPropagationEnabled sql.NullBool - sJWTGroupsEnabled sql.NullBool - sJWTGroupsClaimName sql.NullString - sJWTAllowGroups []byte - sRoutingPeerDNSResolutionEnabled sql.NullBool - sDNSDomain sql.NullString - sNetworkRange []byte - sLazyConnectionEnabled sql.NullBool - sExtraPeerApprovalEnabled sql.NullBool - sExtraUserApprovalRequired sql.NullBool - sExtraIntegratedValidator sql.NullString - sExtraIntegratedValidatorGroups []byte - ) - - err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( - &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, - &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, - &dnsSettingsDisabledGroups, - &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, - &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, - &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, - &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, - &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, - &sLazyConnectionEnabled, - &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, - &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, - ) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return nil, errors.New("account not found") - } - return nil, err - } - - _ = json.Unmarshal(networkNet, 
&account.Network.Net) - _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) - - account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} - if sPeerLoginExpirationEnabled.Valid { - account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool - } - if sPeerLoginExpiration.Valid { - account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) - } - if sPeerInactivityExpirationEnabled.Valid { - account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool - } - if sPeerInactivityExpiration.Valid { - account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) - } - if sRegularUsersViewBlocked.Valid { - account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool - } - if sGroupsPropagationEnabled.Valid { - account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool - } - if sJWTGroupsEnabled.Valid { - account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool - } - if sJWTGroupsClaimName.Valid { - account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String - } - if sRoutingPeerDNSResolutionEnabled.Valid { - account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool - } - if sDNSDomain.Valid { - account.Settings.DNSDomain = sDNSDomain.String - } - if sLazyConnectionEnabled.Valid { - account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool - } - if sJWTAllowGroups != nil { - _ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) - } - if sNetworkRange != nil { - _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) - } - - if sExtraPeerApprovalEnabled.Valid { - account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool - } - if sExtraUserApprovalRequired.Valid { - account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool - } - if sExtraIntegratedValidator.Valid { - 
account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String - } - if sExtraIntegratedValidatorGroups != nil { - _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) - } - return &account, nil -} - -func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { - const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - - keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { - var sk types.SetupKey - var autoGroups []byte - var expiresAt, updatedAt, lastUsed sql.NullTime - var revoked, ephemeral, allowExtraDNSLabels sql.NullBool - var usedTimes, usageLimit sql.NullInt64 - - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) - - if err == nil { - if expiresAt.Valid { - sk.ExpiresAt = &expiresAt.Time - } - if updatedAt.Valid { - sk.UpdatedAt = updatedAt.Time - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt - } - } - if lastUsed.Valid { - sk.LastUsed = &lastUsed.Time - } - if revoked.Valid { - sk.Revoked = revoked.Bool - } - if usedTimes.Valid { - sk.UsedTimes = int(usedTimes.Int64) - } - if usageLimit.Valid { - sk.UsageLimit = int(usageLimit.Int64) - } - if ephemeral.Valid { - sk.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} - } - } - return sk, err - }) - if err != nil { - return nil, err - } - return keys, nil 
-} - -func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { - const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - - peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { - var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} - var lastLogin sql.NullTime - var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool - var peerStatusLastSeen sql.NullTime - var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool - var ip, extraDNS, netAddr, env, flags, files, connIP []byte - - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, 
&peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) - - if err == nil { - if lastLogin.Valid { - p.LastLogin = &lastLogin.Time - } - if sshEnabled.Valid { - p.SSHEnabled = sshEnabled.Bool - } - if loginExpirationEnabled.Valid { - p.LoginExpirationEnabled = loginExpirationEnabled.Bool - } - if inactivityExpirationEnabled.Valid { - p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool - } - if ephemeral.Valid { - p.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if peerStatusLastSeen.Valid { - p.Status.LastSeen = peerStatusLastSeen.Time - } - if peerStatusConnected.Valid { - p.Status.Connected = peerStatusConnected.Bool - } - if peerStatusLoginExpired.Valid { - p.Status.LoginExpired = peerStatusLoginExpired.Bool - } - if peerStatusRequiresApproval.Valid { - p.Status.RequiresApproval = peerStatusRequiresApproval.Bool - } - - if ip != nil { - _ = json.Unmarshal(ip, &p.IP) - } - if extraDNS != nil { - _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) - } - if netAddr != nil { - _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) - } - if env != nil { - _ = json.Unmarshal(env, &p.Meta.Environment) - } - if flags != nil { - _ = json.Unmarshal(flags, &p.Meta.Flags) - } - if files != nil { - _ = json.Unmarshal(files, &p.Meta.Files) - } - if connIP != nil { - _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) - } - } - return p, err - }) - if err != nil { - return nil, err - } - return peers, nil -} - -func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { 
- return nil, err - } - users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { - var u types.User - var autoGroups []byte - var lastLogin sql.NullTime - var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil { - if lastLogin.Valid { - u.LastLogin = &lastLogin.Time - } - if isServiceUser.Valid { - u.IsServiceUser = isServiceUser.Bool - } - if nonDeletable.Valid { - u.NonDeletable = nonDeletable.Bool - } - if blocked.Valid { - u.Blocked = blocked.Bool - } - if pendingApproval.Valid { - u.PendingApproval = pendingApproval.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} - } - } - return u, err - }) - if err != nil { - return nil, err - } - return users, nil -} - -func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { - const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { - var g types.Group - var resources []byte - var refID sql.NullInt64 - var refType sql.NullString - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) - if err == nil { - if refID.Valid { - g.IntegrationReference.ID = int(refID.Int64) - } - if refType.Valid { - g.IntegrationReference.IntegrationType = refType.String - } - if resources != nil { - _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} - } - g.GroupPeers = []types.GroupPeer{} - 
g.Peers = []string{} - } - return &g, err - }) - if err != nil { - return nil, err - } - return groups, nil -} - -func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { - const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { - var p types.Policy - var checks []byte - var enabled sql.NullBool - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) - if err == nil { - if enabled.Valid { - p.Enabled = enabled.Bool - } - if checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) - } - } - return &p, err - }) - if err != nil { - return nil, err - } - return policies, nil -} - -func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { - const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { - var r route.Route - var network, domains, peerGroups, groups, accessGroups []byte - var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) - if err == nil { - if keepRoute.Valid { - r.KeepRoute = keepRoute.Bool - } - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } 
- if skipAutoApply.Valid { - r.SkipAutoApply = skipAutoApply.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if network != nil { - _ = json.Unmarshal(network, &r.Network) - } - if domains != nil { - _ = json.Unmarshal(domains, &r.Domains) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - if groups != nil { - _ = json.Unmarshal(groups, &r.Groups) - } - if accessGroups != nil { - _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) - } - } - return r, err - }) - if err != nil { - return nil, err - } - return routes, nil -} - -func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { - const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { - var n nbdns.NameServerGroup - var ns, groups, domains []byte - var primary, enabled, searchDomainsEnabled sql.NullBool - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) - if err == nil { - if primary.Valid { - n.Primary = primary.Bool - } - if enabled.Valid { - n.Enabled = enabled.Bool - } - if searchDomainsEnabled.Valid { - n.SearchDomainsEnabled = searchDomainsEnabled.Bool - } - if ns != nil { - _ = json.Unmarshal(ns, &n.NameServers) - } else { - n.NameServers = []nbdns.NameServer{} - } - if groups != nil { - _ = json.Unmarshal(groups, &n.Groups) - } else { - n.Groups = []string{} - } - if domains != nil { - _ = json.Unmarshal(domains, &n.Domains) - } else { - n.Domains = []string{} - } - } - return n, err - }) - if err != nil { - return nil, err - } - return nsgs, nil -} - -func (s *SqlStore) getPostureChecks(ctx context.Context, 
accountID string) ([]*posture.Checks, error) { - const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { - var c posture.Checks - var checksDef []byte - err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) - if err == nil && checksDef != nil { - _ = json.Unmarshal(checksDef, &c.Checks) - } - return &c, err - }) - if err != nil { - return nil, err - } - return checks, nil -} - -func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { - const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) - if err != nil { - return nil, err - } - result := make([]*networkTypes.Network, len(networks)) - for i := range networks { - result[i] = &networks[i] - } - return result, nil -} - -func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { - const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { - var r routerTypes.NetworkRouter - var peerGroups []byte - var masquerade, enabled sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) - if err == nil { - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = 
enabled.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - } - return r, err - }) - if err != nil { - return nil, err - } - result := make([]*routerTypes.NetworkRouter, len(routers)) - for i := range routers { - result[i] = &routers[i] - } - return result, nil -} - -func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { - const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - return nil, err - } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { - var r resourceTypes.NetworkResource - var prefix []byte - var enabled sql.NullBool - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) - } - } - return r, err - }) - if err != nil { - return nil, err - } - result := make([]*resourceTypes.NetworkResource, len(resources)) - for i := range resources { - result[i] = &resources[i] - } - return result, nil -} - -func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { - const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` - var onboardingFlowPending, signupFormPending sql.NullBool - err := s.pool.QueryRow(ctx, query, accountID).Scan( - &account.Onboarding.AccountID, - &onboardingFlowPending, - &signupFormPending, - &account.Onboarding.CreatedAt, - &account.Onboarding.UpdatedAt, - ) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return err - } - if 
onboardingFlowPending.Valid { - account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool - } - if signupFormPending.Valid { - account.Onboarding.SignupFormPending = signupFormPending.Bool - } - return nil -} - -func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { - if len(userIDs) == 0 { - return nil, nil - } - const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, userIDs) - if err != nil { - return nil, err - } - pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { - var pat types.PersonalAccessToken - var expirationDate, lastUsed sql.NullTime - err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) - if err == nil { - if expirationDate.Valid { - pat.ExpirationDate = &expirationDate.Time - } - if lastUsed.Valid { - pat.LastUsed = &lastUsed.Time - } - } - return pat, err - }) - if err != nil { - return nil, err - } - return pats, nil -} - -func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { - if len(policyIDs) == 0 { - return nil, nil - } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, policyIDs) - if err != nil { - return nil, err - } - rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { - var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte - var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, 
&sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if bidirectional.Valid { - r.Bidirectional = bidirectional.Bool - } - if dest != nil { - _ = json.Unmarshal(dest, &r.Destinations) - } - if destRes != nil { - _ = json.Unmarshal(destRes, &r.DestinationResource) - } - if sources != nil { - _ = json.Unmarshal(sources, &r.Sources) - } - if sourceRes != nil { - _ = json.Unmarshal(sourceRes, &r.SourceResource) - } - if ports != nil { - _ = json.Unmarshal(ports, &r.Ports) - } - if portRanges != nil { - _ = json.Unmarshal(portRanges, &r.PortRanges) - } - } - return &r, err - }) - if err != nil { - return nil, err - } - return rules, nil -} - -func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { - if len(groupIDs) == 0 { - return nil, nil - } - const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, groupIDs) - if err != nil { - return nil, err - } - groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) - if err != nil { - return nil, err - } - return groupPeers, nil -} From ab8a2baa32f5262796768955e028bbde8d337b63 Mon Sep 17 00:00:00 2001 From: crn4 Date: Thu, 30 Oct 2025 18:03:06 +0100 Subject: [PATCH 13/14] minor cleanup changes --- management/server/store/sql_store.go | 14 +++++++++----- management/server/store/sqlstore_bench_test.go | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index b5f4d8cbf..1e698d2ae 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -49,6 +49,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" 
peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 5 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -2101,10 +2106,10 @@ func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { return nil, fmt.Errorf("unable to parse database config: %w", err) } - config.MaxConns = 10 - config.MinConns = 2 - config.MaxConnLifetime = time.Hour - config.HealthCheckPeriod = time.Minute + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod pool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { @@ -2116,7 +2121,6 @@ func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { return nil, fmt.Errorf("unable to ping database: %w", err) } - fmt.Println("Successfully connected to the database!") return pool, nil } diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index b131c69ac..74bdb83b4 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -42,7 +42,7 @@ func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types var account types.Account result := s.db.Model(&account). Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload("UsersG.PATsG"). // have to be specified as this is nested reference Preload(clause.Associations). 
Take(&account, idQueryCondition, accountID) if result.Error != nil { From 435a342a367f339866ac1f8ef66dc9141811c8a7 Mon Sep 17 00:00:00 2001 From: crn4 Date: Thu, 30 Oct 2025 18:40:22 +0100 Subject: [PATCH 14/14] GetAccount method selection based on pg pool --- management/server/store/sql_store.go | 107 +++++++++++++++++ .../server/store/sqlstore_bench_test.go | 110 +++++++++++++++++- 2 files changed, 212 insertions(+), 5 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 1e698d2ae..489c14702 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -783,6 +783,113 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). 
+ Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if 
ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { account, err := s.getAccount(ctx, accountID) if err != nil { return nil, err diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 74bdb83b4..10c6385ed 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -117,6 +117,106 @@ func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types return &account, nil } +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). 
+ Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if 
ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { @@ -400,9 +500,9 @@ func BenchmarkGetAccount(b *testing.B) { } } }) - b.Run("new", func(b *testing.B) { + b.Run("gorm opt", func(b *testing.B) { for range b.N { - _, err := store.GetAccount(ctx, accountID) + _, err := store.GetAccountGormOpt(ctx, accountID) if err != nil { b.Fatalf("GetAccountFast failed: %v", err) } @@ -410,7 +510,7 @@ func BenchmarkGetAccount(b *testing.B) { }) b.Run("raw", func(b *testing.B) { for range b.N { - _, err := store.GetAccountPureSQL(ctx, accountID) + _, err := store.GetAccount(ctx, accountID) if err != nil { b.Fatalf("GetAccountPureSQL failed: %v", err) } @@ -430,8 +530,8 @@ func TestAccountEquivalence(t *testing.T) { expectedF getAccountFunc actualF getAccountFunc }{ - {"old vs new", store.GetAccountSlow, store.GetAccount}, - {"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL}, + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, } for _, tt := range tests {