diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 865051fec..41d2af0c2 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "encoding/json" "errors" "fmt" @@ -16,7 +15,6 @@ import ( "sync" "time" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" @@ -778,852 +776,6 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { - var account types.Account - account.Network = &types.Network{} - const accountQuery = ` - SELECT - id, created_by, created_at, domain, domain_category, is_domain_primary_account, - -- Embedded Network - network_identifier, network_net, network_dns, network_serial, - -- Embedded DNSSettings - dns_settings_disabled_management_groups, - -- Embedded Settings - settings_peer_login_expiration_enabled, settings_peer_login_expiration, - settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, - settings_regular_users_view_blocked, settings_groups_propagation_enabled, - settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, - settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, - settings_lazy_connection_enabled, - -- Embedded ExtraSettings - settings_extra_peer_approval_enabled, settings_extra_user_approval_required, - settings_extra_integrated_validator, settings_extra_integrated_validator_groups - FROM accounts WHERE id = $1` - - var networkNet, dnsSettingsDisabledGroups []byte - var ( - sPeerLoginExpirationEnabled sql.NullBool - sPeerLoginExpiration sql.NullInt64 - sPeerInactivityExpirationEnabled sql.NullBool - sPeerInactivityExpiration sql.NullInt64 - sRegularUsersViewBlocked sql.NullBool - sGroupsPropagationEnabled sql.NullBool - sJWTGroupsEnabled sql.NullBool - sJWTGroupsClaimName sql.NullString - sJWTAllowGroups []byte - sRoutingPeerDNSResolutionEnabled sql.NullBool - sDNSDomain sql.NullString - sNetworkRange []byte - sLazyConnectionEnabled sql.NullBool - sExtraPeerApprovalEnabled sql.NullBool - sExtraUserApprovalRequired sql.NullBool - sExtraIntegratedValidator sql.NullString - sExtraIntegratedValidatorGroups []byte - ) - - err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( - &account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, - &account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial, - &dnsSettingsDisabledGroups, - &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, - &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, - &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, - &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, - &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, - &sLazyConnectionEnabled, - &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, - &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, - ) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return nil, errors.New("account not found") - } - return nil, err - } - - _ = json.Unmarshal(networkNet, &account.Network.Net) - _ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups) - - account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} - if sPeerLoginExpirationEnabled.Valid { - account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool - } - if sPeerLoginExpiration.Valid { - account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) - } - if sPeerInactivityExpirationEnabled.Valid { - account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool - } - if sPeerInactivityExpiration.Valid { - account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) - } - if sRegularUsersViewBlocked.Valid { - account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool - } - if sGroupsPropagationEnabled.Valid { - account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool - } - if sJWTGroupsEnabled.Valid { - account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool - } - if sJWTGroupsClaimName.Valid { - account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String - } - if sRoutingPeerDNSResolutionEnabled.Valid { - account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool - } - if sDNSDomain.Valid { - account.Settings.DNSDomain = sDNSDomain.String - } - if sLazyConnectionEnabled.Valid { - account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool - } - if sJWTAllowGroups != nil { - _ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups) - } - if sNetworkRange != nil { - _ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange) - } - - if sExtraPeerApprovalEnabled.Valid { - account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool - } - if sExtraUserApprovalRequired.Valid { - account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool - } - if sExtraIntegratedValidator.Valid { - account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String - } - if sExtraIntegratedValidatorGroups != nil { - _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) - } - - var wg sync.WaitGroup - errChan := make(chan error, 12) - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { - var sk types.SetupKey - var autoGroups []byte - var expiresAt, updatedAt, lastUsed sql.NullTime - var revoked, ephemeral, allowExtraDNSLabels sql.NullBool - var usedTimes, usageLimit sql.NullInt64 - - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) - - if err == nil { - if expiresAt.Valid { - sk.ExpiresAt = &expiresAt.Time - } - if updatedAt.Valid { - sk.UpdatedAt = updatedAt.Time - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt - } - } - if lastUsed.Valid { - sk.LastUsed = &lastUsed.Time - } - if revoked.Valid { - sk.Revoked = revoked.Bool - } - if usedTimes.Valid { - sk.UsedTimes = int(usedTimes.Int64) - } - if usageLimit.Valid { - sk.UsageLimit = int(usageLimit.Int64) - } - if ephemeral.Valid { - sk.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} - } - } - return sk, err - }) - if err != nil { - errChan <- err - return - } - account.SetupKeysG = keys - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { - var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} - var lastLogin sql.NullTime - var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool - var peerStatusLastSeen sql.NullTime - var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool - var ip, extraDNS, netAddr, env, flags, files, connIP []byte - - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) - - if err == nil { - if lastLogin.Valid { - p.LastLogin = &lastLogin.Time - } - if sshEnabled.Valid { - p.SSHEnabled = sshEnabled.Bool - } - if loginExpirationEnabled.Valid { - p.LoginExpirationEnabled = loginExpirationEnabled.Bool - } - if inactivityExpirationEnabled.Valid { - p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool - } - if ephemeral.Valid { - p.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if peerStatusLastSeen.Valid { - p.Status.LastSeen = peerStatusLastSeen.Time - } - if peerStatusConnected.Valid { - p.Status.Connected = peerStatusConnected.Bool - } - if peerStatusLoginExpired.Valid { - p.Status.LoginExpired = peerStatusLoginExpired.Bool - } - if peerStatusRequiresApproval.Valid { - p.Status.RequiresApproval = peerStatusRequiresApproval.Bool - } - - if ip != nil { - _ = json.Unmarshal(ip, &p.IP) - } - if extraDNS != nil { - _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) - } - if netAddr != nil { - _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) - } - if env != nil { - _ = json.Unmarshal(env, &p.Meta.Environment) - } - if flags != nil { - _ = json.Unmarshal(flags, &p.Meta.Flags) - } - if files != nil { - _ = json.Unmarshal(files, &p.Meta.Files) - } - if connIP != nil { - _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) - } - } - return p, err - }) - if err != nil { - errChan <- err - return - } - account.PeersG = peers - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { - var u types.User - var autoGroups []byte - var lastLogin sql.NullTime - var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil { - if lastLogin.Valid { - u.LastLogin = &lastLogin.Time - } - if isServiceUser.Valid { - u.IsServiceUser = isServiceUser.Bool - } - if nonDeletable.Valid { - u.NonDeletable = nonDeletable.Bool - } - if blocked.Valid { - u.Blocked = blocked.Bool - } - if pendingApproval.Valid { - u.PendingApproval = pendingApproval.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} - } - } - return u, err - }) - if err != nil { - errChan <- err - return - } - account.UsersG = users - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { - var g types.Group - var resources []byte - var refID sql.NullInt64 - var refType sql.NullString - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) - if err == nil { - if refID.Valid { - g.IntegrationReference.ID = int(refID.Int64) - } - if refType.Valid { - g.IntegrationReference.IntegrationType = refType.String - } - if resources != nil { - _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} - } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} - } - return &g, err - }) - if err != nil { - errChan <- err - return - } - account.GroupsG = groups - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { - var p types.Policy - var checks []byte - var enabled sql.NullBool - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) - if err == nil { - if enabled.Valid { - p.Enabled = enabled.Bool - } - if checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) - } - } - return &p, err - }) - if err != nil { - errChan <- err - return - } - account.Policies = policies - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { - var r route.Route - var network, domains, peerGroups, groups, accessGroups []byte - var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) - if err == nil { - if keepRoute.Valid { - r.KeepRoute = keepRoute.Bool - } - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if skipAutoApply.Valid { - r.SkipAutoApply = skipAutoApply.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if network != nil { - _ = json.Unmarshal(network, &r.Network) - } - if domains != nil { - _ = json.Unmarshal(domains, &r.Domains) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - if groups != nil { - _ = json.Unmarshal(groups, &r.Groups) - } - if accessGroups != nil { - _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.RoutesG = routes - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { - var n nbdns.NameServerGroup - var ns, groups, domains []byte - var primary, enabled, searchDomainsEnabled sql.NullBool - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) - if err == nil { - if primary.Valid { - n.Primary = primary.Bool - } - if enabled.Valid { - n.Enabled = enabled.Bool - } - if searchDomainsEnabled.Valid { - n.SearchDomainsEnabled = searchDomainsEnabled.Bool - } - if ns != nil { - _ = json.Unmarshal(ns, &n.NameServers) - } else { - n.NameServers = []nbdns.NameServer{} - } - if groups != nil { - _ = json.Unmarshal(groups, &n.Groups) - } else { - n.Groups = []string{} - } - if domains != nil { - _ = json.Unmarshal(domains, &n.Domains) - } else { - n.Domains = []string{} - } - } - return n, err - }) - if err != nil { - errChan <- err - return - } - account.NameServerGroupsG = nsgs - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { - var c posture.Checks - var checksDef []byte - err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) - if err == nil && checksDef != nil { - _ = json.Unmarshal(checksDef, &c.Checks) - } - return &c, err - }) - if err != nil { - errChan <- err - return - } - account.PostureChecks = checks - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) - if err != nil { - errChan <- err - return - } - account.Networks = make([]*networkTypes.Network, len(networks)) - for i := range networks { - account.Networks[i] = &networks[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { - var r routerTypes.NetworkRouter - var peerGroups []byte - var masquerade, enabled sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) - if err == nil { - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) - for i := range routers { - account.NetworkRouters[i] = &routers[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { - var r resourceTypes.NetworkResource - var prefix []byte - var enabled sql.NullBool - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources)) - for i := range resources { - account.NetworkResources[i] = &resources[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` - var onboardingFlowPending, signupFormPending sql.NullBool - err := s.pool.QueryRow(ctx, query, accountID).Scan( - &account.Onboarding.AccountID, - &onboardingFlowPending, - &signupFormPending, - &account.Onboarding.CreatedAt, - &account.Onboarding.UpdatedAt, - ) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - errChan <- err - return - } - if onboardingFlowPending.Valid { - account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool - } - if signupFormPending.Valid { - account.Onboarding.SignupFormPending = signupFormPending.Bool - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - var userIDs []string - for _, u := range account.UsersG { - userIDs = append(userIDs, u.Id) - } - var policyIDs []string - for _, p := range account.Policies { - policyIDs = append(policyIDs, p.ID) - } - var groupIDs []string - for _, g := range account.GroupsG { - groupIDs = append(groupIDs, g.ID) - } - - wg.Add(3) - errChan = make(chan error, 3) - - var pats []types.PersonalAccessToken - go func() { - defer wg.Done() - if len(userIDs) == 0 { - return - } - const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, userIDs) - if err != nil { - errChan <- err - return - } - pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { - var pat types.PersonalAccessToken - var expirationDate, lastUsed sql.NullTime - err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) - if err == nil { - if expirationDate.Valid { - pat.ExpirationDate = &expirationDate.Time - } - if lastUsed.Valid { - pat.LastUsed = &lastUsed.Time - } - } - return pat, err - }) - if err != nil { - errChan <- err - } - }() - - var rules []*types.PolicyRule - go func() { - defer wg.Done() - if len(policyIDs) == 0 { - return - } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, policyIDs) - if err != nil { - errChan <- err - return - } - rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { - var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte - var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if bidirectional.Valid { - r.Bidirectional = bidirectional.Bool - } - if dest != nil { - _ = json.Unmarshal(dest, &r.Destinations) - } - if destRes != nil { - _ = json.Unmarshal(destRes, &r.DestinationResource) - } - if sources != nil { - _ = json.Unmarshal(sources, &r.Sources) - } - if sourceRes != nil { - _ = json.Unmarshal(sourceRes, &r.SourceResource) - } - if ports != nil { - _ = json.Unmarshal(ports, &r.Ports) - } - if portRanges != nil { - _ = json.Unmarshal(portRanges, &r.PortRanges) - } - } - return &r, err - }) - if err != nil { - errChan <- err - } - }() - - var groupPeers []types.GroupPeer - go func() { - defer wg.Done() - if len(groupIDs) == 0 { - return - } - const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, groupIDs) - if err != nil { - errChan <- err - return - } - groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) - if err != nil { - errChan <- err - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - patsByUserID := make(map[string][]*types.PersonalAccessToken) - for i := range pats { - pat := &pats[i] - patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) - pat.UserID = "" - } - - rulesByPolicyID := make(map[string][]*types.PolicyRule) - for _, rule := range rules { - rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) - } - - peersByGroupID := make(map[string][]string) - for _, gp := range groupPeers { - peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) - } - - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) - for i := range account.SetupKeysG { - key := &account.SetupKeysG[i] - account.SetupKeys[key.Key] = key - } - - account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) - for i := range account.PeersG { - peer := &account.PeersG[i] - account.Peers[peer.ID] = peer - } - - account.Users = make(map[string]*types.User, len(account.UsersG)) - for i := range account.UsersG { - user := &account.UsersG[i] - user.PATs = make(map[string]*types.PersonalAccessToken) - if userPats, ok := patsByUserID[user.Id]; ok { - for j := range userPats { - pat := userPats[j] - user.PATs[pat.ID] = pat - } - } - account.Users[user.Id] = user - } - - for i := range account.Policies { - policy := account.Policies[i] - if policyRules, ok := rulesByPolicyID[policy.ID]; ok { - policy.Rules = policyRules - } - } - - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) - for i := range account.GroupsG { - group := account.GroupsG[i] - if peerIDs, ok := peersByGroupID[group.ID]; ok { - group.Peers = peerIDs - } - account.Groups[group.ID] = group - } - - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for i := range account.RoutesG { - route := &account.RoutesG[i] - account.Routes[route.ID] = route - } - - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for i := range account.NameServerGroupsG { - nsg := &account.NameServerGroupsG[i] - nsg.AccountID = "" - account.NameServerGroups[nsg.ID] = nsg - } - - account.SetupKeysG = nil - account.PeersG = nil - account.UsersG = nil - account.GroupsG = nil - account.RoutesG = nil - account.NameServerGroupsG = nil - - return &account, nil -} - -func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -1634,8 +786,7 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. var account types.Account result := s.db.Model(&account). - // Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload("UsersG.PATsG"). Preload("Policies.Rules"). Preload("SetupKeysG"). Preload("PeersG"). @@ -1657,21 +808,14 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - // for i, policy := range account.Policies { - // var rules []*types.PolicyRule - // err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - // if err != nil { - // return nil, status.Errorf(status.NotFound, "rule not found") - // } - // account.Policies[i].Rules = rules - // } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { if key.UpdatedAt.IsZero() { key.UpdatedAt = key.CreatedAt } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil @@ -1689,6 +833,9 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. pat.UserID = "" user.PATs[pat.ID] = &pat } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } account.Users[user.Id] = &user user.PATsG = nil } @@ -1700,21 +847,13 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. for i, gp := range group.GroupPeers { group.Peers[i] = gp.PeerID } + if group.Resources == nil { + group.Resources = []types.Resource{} + } account.Groups[group.ID] = group } account.GroupsG = nil - // var groupPeers []types.GroupPeer - // s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - // Find(&groupPeers) - // for _, groupPeer := range groupPeers { - // if group, ok := account.Groups[groupPeer.GroupID]; ok { - // group.Peers = append(group.Peers, groupPeer.PeerID) - // } else { - // log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) - // } - // } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) for _, route := range account.RoutesG { account.Routes[route.ID] = &route @@ -1724,6 +863,15 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types. account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) for _, ns := range account.NameServerGroupsG { ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } account.NameServerGroups[ns.ID] = &ns } account.NameServerGroupsG = nil @@ -2070,7 +1218,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } - pool, err := connectDB(context.Background(), dsn) + pool, err := connectToPgDb(context.Background(), dsn) if err != nil { return nil, err } @@ -2083,7 +1231,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe return store, nil } -func connectDB(ctx context.Context, dsn string) (*pgxpool.Pool, error) { +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { config, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("unable to parse database config: %w", err) diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 2b0256bd8..8c94f53c4 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -30,7 +30,6 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) @@ -153,7 +152,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, string) { b.Fatalf("failed to connect database: %v", err) } - pool, err := connectDB(context.Background(), dsn) + pool, err := connectDBforTest(context.Background(), dsn) if err != nil { b.Fatalf("failed to connect database: %v", err) } @@ -434,7 +433,7 @@ func TestAccountEquivalence(t *testing.T) { expectedF getAccountFunc actualF getAccountFunc }{ - // {"old vs new", store.GetAccountSlow, store.GetAccount}, + {"old vs new", store.GetAccountSlow, store.GetAccount}, {"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL}, } @@ -461,9 +460,6 @@ func TestAccountEquivalence(t *testing.T) { } func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { - // normalizeNilSlices(expected) - // normalizeNilSlices(actual) - assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") @@ -566,136 +562,290 @@ func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { } } -func normalizeNilSlices(acc *types.Account) { - if acc == nil { - return +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err } - if acc.Policies == nil { - acc.Policies = []*types.Policy{} - } - if acc.PostureChecks == nil { - acc.PostureChecks = []*posture.Checks{} - } - if acc.Networks == nil { - acc.Networks = []*networkTypes.Network{} - } - if acc.NetworkRouters == nil { - acc.NetworkRouters = []*routerTypes.NetworkRouter{} - } - if acc.NetworkResources == nil { - acc.NetworkResources = []*resourceTypes.NetworkResource{} - } - if acc.DNSSettings.DisabledManagementGroups == nil { - acc.DNSSettings.DisabledManagementGroups = []string{} - } + var wg sync.WaitGroup + errChan := make(chan error, 12) - for _, key := range acc.SetupKeys { - if key.AutoGroups == nil { - key.AutoGroups = []string{} + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - for _, peer := range acc.Peers { - if peer.ExtraDNSLabels == nil { - peer.ExtraDNSLabels = []string{} + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - for _, user := range acc.Users { - if user.AutoGroups == nil { - user.AutoGroups = []string{} - } + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" } - for _, group := range acc.Groups { - if group.Peers == nil { - group.Peers = []string{} - } - if group.Resources == nil { - group.Resources = []types.Resource{} - } - if group.GroupPeers == nil { - group.GroupPeers = []types.GroupPeer{} - } + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) } - for _, route := range acc.Routes { - if route.Domains == nil { - route.Domains = domain.List{} - } - if route.PeerGroups == nil { - route.PeerGroups = []string{} - } - if route.Groups == nil { - route.Groups = []string{} - } - if route.AccessControlGroups == nil { - route.AccessControlGroups = []string{} - } + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) } - for _, nsg := range acc.NameServerGroups { - if nsg.NameServers == nil { - nsg.NameServers = []nbdns.NameServer{} - } - if nsg.Groups == nil { - nsg.Groups = []string{} - } - if nsg.Domains == nil { - nsg.Domains = []string{} - } + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key } - for _, policy := range acc.Policies { - if policy.SourcePostureChecks == nil { - policy.SourcePostureChecks = []string{} - } - if policy.Rules == nil { - policy.Rules = []*types.PolicyRule{} - } - for _, rule := range policy.Rules { - if rule.Destinations == nil { - rule.Destinations = []string{} - } - if rule.Sources == nil { - rule.Sources = []string{} - } - if rule.Ports == nil { - rule.Ports = []string{} - } - if rule.PortRanges == nil { - rule.PortRanges = []types.RulePortRange{} + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat } } + account.Users[user.Id] = user } - for _, check := range acc.PostureChecks { - if check.Checks.GeoLocationCheck != nil { - if check.Checks.GeoLocationCheck.Locations == nil { - check.Checks.GeoLocationCheck.Locations = []posture.Location{} - } - } - if check.Checks.PeerNetworkRangeCheck != nil { - if check.Checks.PeerNetworkRangeCheck.Ranges == nil { - check.Checks.PeerNetworkRangeCheck.Ranges = []netip.Prefix{} - } - } - if check.Checks.ProcessCheck != nil { - if check.Checks.ProcessCheck.Processes == nil { - check.Checks.ProcessCheck.Processes = []posture.Process{} - } + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules } } - for _, router := range acc.NetworkRouters { - if router.PeerGroups == nil { - router.PeerGroups = []string{} + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs } + account.Groups[group.ID] = group } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil } -func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { var account types.Account account.Network = &types.Network{} const accountQuery = ` @@ -814,729 +964,566 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty if sExtraIntegratedValidatorGroups != nil { _ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups) } - - var wg sync.WaitGroup - errChan := make(chan error, 12) - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { - var sk types.SetupKey - var autoGroups []byte - var expiresAt, updatedAt, lastUsed sql.NullTime - var revoked, ephemeral, allowExtraDNSLabels sql.NullBool - var usedTimes, usageLimit sql.NullInt64 - - err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) - - if err == nil { - if expiresAt.Valid { - sk.ExpiresAt = &expiresAt.Time - } - if updatedAt.Valid { - sk.UpdatedAt = updatedAt.Time - if sk.UpdatedAt.IsZero() { - sk.UpdatedAt = sk.CreatedAt - } - } - if lastUsed.Valid { - sk.LastUsed = &lastUsed.Time - } - if revoked.Valid { - sk.Revoked = revoked.Bool - } - if usedTimes.Valid { - sk.UsedTimes = int(usedTimes.Int64) - } - if usageLimit.Valid { - sk.UsageLimit = int(usageLimit.Int64) - } - if ephemeral.Valid { - sk.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} - } - } - return sk, err - }) - if err != nil { - errChan <- err - return - } - account.SetupKeysG = keys - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - - peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { - var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} - var lastLogin sql.NullTime - var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool - var peerStatusLastSeen sql.NullTime - var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool - var ip, extraDNS, netAddr, env, flags, files, connIP []byte - - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) - - if err == nil { - if lastLogin.Valid { - p.LastLogin = &lastLogin.Time - } - if sshEnabled.Valid { - p.SSHEnabled = sshEnabled.Bool - } - if loginExpirationEnabled.Valid { - p.LoginExpirationEnabled = loginExpirationEnabled.Bool - } - if inactivityExpirationEnabled.Valid { - p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool - } - if ephemeral.Valid { - p.Ephemeral = ephemeral.Bool - } - if allowExtraDNSLabels.Valid { - p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool - } - if peerStatusLastSeen.Valid { - p.Status.LastSeen = peerStatusLastSeen.Time - } - if peerStatusConnected.Valid { - p.Status.Connected = peerStatusConnected.Bool - } - if peerStatusLoginExpired.Valid { - p.Status.LoginExpired = peerStatusLoginExpired.Bool - } - if peerStatusRequiresApproval.Valid { - p.Status.RequiresApproval = peerStatusRequiresApproval.Bool - } - - if ip != nil { - _ = json.Unmarshal(ip, &p.IP) - } - if extraDNS != nil { - _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) - } - if netAddr != nil { - _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) - } - if env != nil { - _ = json.Unmarshal(env, &p.Meta.Environment) - } - if flags != nil { - _ = json.Unmarshal(flags, &p.Meta.Flags) - } - if files != nil { - _ = json.Unmarshal(files, &p.Meta.Files) - } - if connIP != nil { - _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) - } - } - return p, err - }) - if err != nil { - errChan <- err - return - } - account.PeersG = peers - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { - var u types.User - var autoGroups []byte - var lastLogin sql.NullTime - var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) - if err == nil { - if lastLogin.Valid { - u.LastLogin = &lastLogin.Time - } - if isServiceUser.Valid { - u.IsServiceUser = isServiceUser.Bool - } - if nonDeletable.Valid { - u.NonDeletable = nonDeletable.Bool - } - if blocked.Valid { - u.Blocked = blocked.Bool - } - if pendingApproval.Valid { - u.PendingApproval = pendingApproval.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} - } - } - return u, err - }) - if err != nil { - errChan <- err - return - } - account.UsersG = users - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { - var g types.Group - var resources []byte - var refID sql.NullInt64 - var refType sql.NullString - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) - if err == nil { - if refID.Valid { - g.IntegrationReference.ID = int(refID.Int64) - } - if refType.Valid { - g.IntegrationReference.IntegrationType = refType.String - } - if resources != nil { - _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} - } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} - } - return &g, err - }) - if err != nil { - errChan <- err - return - } - account.GroupsG = groups - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { - var p types.Policy - var checks []byte - var enabled sql.NullBool - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) - if err == nil { - if enabled.Valid { - p.Enabled = enabled.Bool - } - if checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) - } - } - return &p, err - }) - if err != nil { - errChan <- err - return - } - account.Policies = policies - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { - var r route.Route - var network, domains, peerGroups, groups, accessGroups []byte - var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) - if err == nil { - if keepRoute.Valid { - r.KeepRoute = keepRoute.Bool - } - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if skipAutoApply.Valid { - r.SkipAutoApply = skipAutoApply.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if network != nil { - _ = json.Unmarshal(network, &r.Network) - } - if domains != nil { - _ = json.Unmarshal(domains, &r.Domains) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - if groups != nil { - _ = json.Unmarshal(groups, &r.Groups) - } - if accessGroups != nil { - _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.RoutesG = routes - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { - var n nbdns.NameServerGroup - var ns, groups, domains []byte - var primary, enabled, searchDomainsEnabled sql.NullBool - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) - if err == nil { - if primary.Valid { - n.Primary = primary.Bool - } - if enabled.Valid { - n.Enabled = enabled.Bool - } - if searchDomainsEnabled.Valid { - n.SearchDomainsEnabled = searchDomainsEnabled.Bool - } - if ns != nil { - _ = json.Unmarshal(ns, &n.NameServers) - } else { - n.NameServers = []nbdns.NameServer{} - } - if groups != nil { - _ = json.Unmarshal(groups, &n.Groups) - } else { - n.Groups = []string{} - } - if domains != nil { - _ = json.Unmarshal(domains, &n.Domains) - } else { - n.Domains = []string{} - } - } - return n, err - }) - if err != nil { - errChan <- err - return - } - account.NameServerGroupsG = nsgs - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { - var c posture.Checks - var checksDef []byte - err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) - if err == nil && checksDef != nil { - _ = json.Unmarshal(checksDef, &c.Checks) - } - return &c, err - }) - if err != nil { - errChan <- err - return - } - account.PostureChecks = checks - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) - if err != nil { - errChan <- err - return - } - account.Networks = make([]*networkTypes.Network, len(networks)) - for i := range networks { - account.Networks[i] = &networks[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { - var r routerTypes.NetworkRouter - var peerGroups []byte - var masquerade, enabled sql.NullBool - var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) - if err == nil { - if masquerade.Valid { - r.Masquerade = masquerade.Bool - } - if enabled.Valid { - r.Enabled = enabled.Bool - } - if metric.Valid { - r.Metric = int(metric.Int64) - } - if peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) - for i := range routers { - account.NetworkRouters[i] = &routers[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) - if err != nil { - errChan <- err - return - } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { - var r resourceTypes.NetworkResource - var prefix []byte - var enabled sql.NullBool - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) - } - } - return r, err - }) - if err != nil { - errChan <- err - return - } - account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources)) - for i := range resources { - account.NetworkResources[i] = &resources[i] - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` - var onboardingFlowPending, signupFormPending sql.NullBool - err := s.pool.QueryRow(ctx, query, accountID).Scan( - &account.Onboarding.AccountID, - &onboardingFlowPending, - &signupFormPending, - &account.Onboarding.CreatedAt, - &account.Onboarding.UpdatedAt, - ) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - errChan <- err - return - } - if onboardingFlowPending.Valid { - account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool - } - if signupFormPending.Valid { - account.Onboarding.SignupFormPending = signupFormPending.Bool - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - var userIDs []string - for _, u := range account.UsersG { - userIDs = append(userIDs, u.Id) - } - var policyIDs []string - for _, p := range account.Policies { - policyIDs = append(policyIDs, p.ID) - } - var groupIDs []string - for _, g := range account.GroupsG { - groupIDs = append(groupIDs, g.ID) - } - - wg.Add(3) - errChan = make(chan error, 3) - - var pats []types.PersonalAccessToken - go func() { - defer wg.Done() - if len(userIDs) == 0 { - return - } - const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, userIDs) - if err != nil { - errChan <- err - return - } - pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { - var pat types.PersonalAccessToken - var expirationDate, lastUsed sql.NullTime - err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) - if err == nil { - if expirationDate.Valid { - pat.ExpirationDate = &expirationDate.Time - } - if lastUsed.Valid { - pat.LastUsed = &lastUsed.Time - } - } - return pat, err - }) - if err != nil { - errChan <- err - } - }() - - var rules []*types.PolicyRule - go func() { - defer wg.Done() - if len(policyIDs) == 0 { - return - } - const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, policyIDs) - if err != nil { - errChan <- err - return - } - rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { - var r types.PolicyRule - var dest, destRes, sources, sourceRes, ports, portRanges []byte - var enabled, bidirectional sql.NullBool - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) - if err == nil { - if enabled.Valid { - r.Enabled = enabled.Bool - } - if bidirectional.Valid { - r.Bidirectional = bidirectional.Bool - } - if dest != nil { - _ = json.Unmarshal(dest, &r.Destinations) - } - if destRes != nil { - _ = json.Unmarshal(destRes, &r.DestinationResource) - } - if sources != nil { - _ = json.Unmarshal(sources, &r.Sources) - } - if sourceRes != nil { - _ = json.Unmarshal(sourceRes, &r.SourceResource) - } - if ports != nil { - _ = json.Unmarshal(ports, &r.Ports) - } - if portRanges != nil { - _ = json.Unmarshal(portRanges, &r.PortRanges) - } - } - return &r, err - }) - if err != nil { - errChan <- err - } - }() - - var groupPeers []types.GroupPeer - go func() { - defer wg.Done() - if len(groupIDs) == 0 { - return - } - const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` - rows, err := s.pool.Query(ctx, query, groupIDs) - if err != nil { - errChan <- err - return - } - groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) - if err != nil { - errChan <- err - } - }() - - wg.Wait() - close(errChan) - for e := range errChan { - if e != nil { - return nil, e - } - } - - patsByUserID := make(map[string][]*types.PersonalAccessToken) - for i := range pats { - pat := &pats[i] - patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) - pat.UserID = "" - } - - rulesByPolicyID := make(map[string][]*types.PolicyRule) - for _, rule := range rules { - rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) - } - - peersByGroupID := make(map[string][]string) - for _, gp := range groupPeers { - peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) - } - - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) - for i := range account.SetupKeysG { - key := &account.SetupKeysG[i] - account.SetupKeys[key.Key] = key - } - - account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) - for i := range account.PeersG { - peer := &account.PeersG[i] - account.Peers[peer.ID] = peer - } - - account.Users = make(map[string]*types.User, len(account.UsersG)) - for i := range account.UsersG { - user := &account.UsersG[i] - user.PATs = make(map[string]*types.PersonalAccessToken) - if userPats, ok := patsByUserID[user.Id]; ok { - for j := range userPats { - pat := userPats[j] - user.PATs[pat.ID] = pat - } - } - account.Users[user.Id] = user - } - - for i := range account.Policies { - policy := account.Policies[i] - if policyRules, ok := rulesByPolicyID[policy.ID]; ok { - policy.Rules = policyRules - } - } - - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) - for i := range account.GroupsG { - group := account.GroupsG[i] - if peerIDs, ok := peersByGroupID[group.ID]; ok { - group.Peers = peerIDs - } - account.Groups[group.ID] = group - } - - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for i := range account.RoutesG { - route := &account.RoutesG[i] - account.Routes[route.ID] = route - } - - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for i := range account.NameServerGroupsG { - nsg := &account.NameServerGroupsG[i] - nsg.AccountID = "" - account.NameServerGroups[nsg.ID] = nsg - } - - account.SetupKeysG = nil - account.PeersG = nil - account.UsersG = nil - account.GroupsG = nil - account.RoutesG = nil - account.NameServerGroupsG = nil - return &account, nil } + +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + var ip, extraDNS, netAddr, env, flags, files, connIP []byte + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &account.Onboarding.CreatedAt, + &account.Onboarding.UpdatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +}