new raw sql get account method

This commit is contained in:
crn4
2025-10-16 23:54:17 +02:00
parent cd9a867ad0
commit 5e79cc0176
2 changed files with 641 additions and 28 deletions

View File

@@ -15,6 +15,8 @@ import (
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@@ -55,6 +57,7 @@ type SqlStore struct {
metrics telemetry.AppMetrics
installationPK int
storeEngine types.Engine
pool *pgxpool.Pool
}
type installation struct {
@@ -774,6 +777,560 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
var account types.Account
account.Network = &types.Network{}
const accountQuery = `
SELECT
id, created_by, created_at, domain, domain_category, is_domain_primary_account,
network_identifier, network_net, network_dns, network_serial,
dns_settings_disabled_management_groups
FROM accounts WHERE id = $1`
var networkNet, dnsSettingsDisabledGroups []byte
err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan(
&account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount,
&account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial,
&dnsSettingsDisabledGroups,
)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, errors.New("account not found")
}
return nil, err
}
_ = json.Unmarshal(networkNet, &account.Network.Net)
_ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups)
var wg sync.WaitGroup
errChan := make(chan error, 12)
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) {
var sk types.SetupKey
var autoGroups []byte
err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &sk.ExpiresAt, &sk.UpdatedAt, &sk.Revoked, &sk.UsedTimes, &sk.LastUsed, &autoGroups, &sk.UsageLimit, &sk.Ephemeral, &sk.AllowExtraDNSLabels)
if err == nil && autoGroups != nil {
_ = json.Unmarshal(autoGroups, &sk.AutoGroups)
}
if sk.UpdatedAt.IsZero() {
sk.UpdatedAt = sk.CreatedAt
}
return sk, err
})
if err != nil {
errChan <- err
return
}
account.SetupKeysG = keys
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) {
var p nbpeer.Peer
p.Status = &nbpeer.PeerStatus{}
var ip, extraDNS, netAddr, env, flags, files, connIP []byte
err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID)
if err == nil {
if ip != nil {
_ = json.Unmarshal(ip, &p.IP)
}
if extraDNS != nil {
_ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels)
}
if netAddr != nil {
_ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses)
}
if env != nil {
_ = json.Unmarshal(env, &p.Meta.Environment)
}
if flags != nil {
_ = json.Unmarshal(flags, &p.Meta.Flags)
}
if files != nil {
_ = json.Unmarshal(files, &p.Meta.Files)
}
if connIP != nil {
_ = json.Unmarshal(connIP, &p.Location.ConnectionIP)
}
}
return p, err
})
if err != nil {
errChan <- err
return
}
account.PeersG = peers
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
var u types.User
var autoGroups []byte
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &u.IsServiceUser, &u.NonDeletable, &u.ServiceUserName, &autoGroups, &u.Blocked, &u.PendingApproval, &u.LastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
if err == nil && autoGroups != nil {
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
}
return u, err
})
if err != nil {
errChan <- err
return
}
account.UsersG = users
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) {
var g types.Group
var resources []byte
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &g.IntegrationReference.ID, &g.IntegrationReference.IntegrationType)
if err == nil && resources != nil {
_ = json.Unmarshal(resources, &g.Resources)
}
return &g, err
})
if err != nil {
errChan <- err
return
}
account.GroupsG = groups
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) {
var p types.Policy
var checks []byte
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks)
if err == nil && checks != nil {
_ = json.Unmarshal(checks, &p.SourcePostureChecks)
}
return &p, err
})
if err != nil {
errChan <- err
return
}
account.Policies = policies
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) {
var r route.Route
var network, domains, peerGroups, groups, accessGroups []byte
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply)
if err == nil {
if network != nil {
_ = json.Unmarshal(network, &r.Network)
}
if domains != nil {
_ = json.Unmarshal(domains, &r.Domains)
}
if peerGroups != nil {
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
}
if groups != nil {
_ = json.Unmarshal(groups, &r.Groups)
}
if accessGroups != nil {
_ = json.Unmarshal(accessGroups, &r.AccessControlGroups)
}
}
return r, err
})
if err != nil {
errChan <- err
return
}
account.RoutesG = routes
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) {
var n nbdns.NameServerGroup
var ns, groups, domains []byte
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled)
if err == nil {
if ns != nil {
_ = json.Unmarshal(ns, &n.NameServers)
}
if groups != nil {
_ = json.Unmarshal(groups, &n.Groups)
}
if domains != nil {
_ = json.Unmarshal(domains, &n.Domains)
}
}
return n, err
})
if err != nil {
errChan <- err
return
}
account.NameServerGroupsG = nsgs
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
var c posture.Checks
var checksDef []byte
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
if err == nil && checksDef != nil {
_ = json.Unmarshal(checksDef, &c.Checks)
}
return &c, err
})
if err != nil {
errChan <- err
return
}
account.PostureChecks = checks
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network])
if err != nil {
errChan <- err
return
}
account.Networks = make([]*networkTypes.Network, len(networks))
for i := range networks {
account.Networks[i] = &networks[i]
}
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) {
var r routerTypes.NetworkRouter
var peerGroups []byte
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled)
if err == nil && peerGroups != nil {
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
}
return &r, err
})
if err != nil {
errChan <- err
return
}
account.NetworkRouters = routers
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) {
var r resourceTypes.NetworkResource
var prefix []byte
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled)
if err == nil && prefix != nil {
_ = json.Unmarshal(prefix, &r.Prefix)
}
return &r, err
})
if err != nil {
errChan <- err
return
}
account.NetworkResources = resources
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
err := s.pool.QueryRow(ctx, query, accountID).Scan(
&account.Onboarding.AccountID,
&account.Onboarding.OnboardingFlowPending,
&account.Onboarding.SignupFormPending,
&account.Onboarding.CreatedAt,
&account.Onboarding.UpdatedAt,
)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
var userIDs []string
for _, u := range account.UsersG {
userIDs = append(userIDs, u.Id)
}
var policyIDs []string
for _, p := range account.Policies {
policyIDs = append(policyIDs, p.ID)
}
var groupIDs []string
for _, g := range account.GroupsG {
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
var pats []types.PersonalAccessToken
go func() {
defer wg.Done()
if len(userIDs) == 0 {
return
}
const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, userIDs)
if err != nil {
errChan <- err
return
}
pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken])
if err != nil {
errChan <- err
}
}()
var rules []*types.PolicyRule
go func() {
defer wg.Done()
if len(policyIDs) == 0 {
return
}
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, policyIDs)
if err != nil {
errChan <- err
return
}
rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
var r types.PolicyRule
var dest, destRes, sources, sourceRes, ports, portRanges []byte
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges)
if err == nil {
if dest != nil {
_ = json.Unmarshal(dest, &r.Destinations)
}
if destRes != nil {
_ = json.Unmarshal(destRes, &r.DestinationResource)
}
if sources != nil {
_ = json.Unmarshal(sources, &r.Sources)
}
if sourceRes != nil {
_ = json.Unmarshal(sourceRes, &r.SourceResource)
}
if ports != nil {
_ = json.Unmarshal(ports, &r.Ports)
}
if portRanges != nil {
_ = json.Unmarshal(portRanges, &r.PortRanges)
}
}
return &r, err
})
if err != nil {
errChan <- err
}
}()
var groupPeers []types.GroupPeer
go func() {
defer wg.Done()
if len(groupIDs) == 0 {
return
}
const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, groupIDs)
if err != nil {
errChan <- err
return
}
groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
patsByUserID := make(map[string][]*types.PersonalAccessToken)
for i := range pats {
pat := &pats[i]
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
pat.UserID = ""
}
rulesByPolicyID := make(map[string][]*types.PolicyRule)
for _, rule := range rules {
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
}
peersByGroupID := make(map[string][]string)
for _, gp := range groupPeers {
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
account.SetupKeys[key.Key] = key
}
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for i := range account.PeersG {
peer := &account.PeersG[i]
account.Peers[peer.ID] = peer
}
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user.PATs = make(map[string]*types.PersonalAccessToken)
if userPats, ok := patsByUserID[user.Id]; ok {
for j := range userPats {
pat := userPats[j]
user.PATs[pat.ID] = pat
}
}
account.Users[user.Id] = user
}
for i := range account.Policies {
policy := account.Policies[i]
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
policy.Rules = policyRules
}
}
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for i := range account.GroupsG {
group := account.GroupsG[i]
if peerIDs, ok := peersByGroupID[group.ID]; ok {
group.Peers = peerIDs
}
account.Groups[group.ID] = group
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for i := range account.RoutesG {
route := &account.RoutesG[i]
account.Routes[route.ID] = route
}
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for i := range account.NameServerGroupsG {
nsg := &account.NameServerGroupsG[i]
nsg.AccountID = ""
account.NameServerGroups[nsg.ID] = nsg
}
account.SetupKeysG = nil
account.PeersG = nil
account.UsersG = nil
account.GroupsG = nil
account.RoutesG = nil
account.NameServerGroupsG = nil
return &account, nil
}
func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
@@ -784,9 +1341,20 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
var account types.Account
result := s.db.Model(&account).
Omit("GroupsG").
// Omit("GroupsG").
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
Preload("Policies.Rules").
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
Preload("Networks").
Preload("NetworkRouters").
Preload("NetworkResources").
Preload("Onboarding").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -797,24 +1365,27 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*types.PolicyRule
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
// for i, policy := range account.Policies {
// var rules []*types.PolicyRule
// err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
// if err != nil {
// return nil, status.Errorf(status.NotFound, "rule not found")
// }
// account.Policies[i].Rules = rules
// }
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
account.SetupKeys[key.Key] = &key
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
account.Peers[peer.ID] = &peer
}
account.PeersG = nil
@@ -822,38 +1393,45 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
for _, user := range account.UsersG {
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
account.Users[user.Id] = user.Copy()
account.Users[user.Id] = &user
user.PATsG = nil
}
account.UsersG = nil
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
group.Peers = make([]string, len(group.GroupPeers))
for i, gp := range group.GroupPeers {
group.Peers[i] = gp.PeerID
}
account.Groups[group.ID] = group
}
account.GroupsG = nil
var groupPeers []types.GroupPeer
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
Find(&groupPeers)
for _, groupPeer := range groupPeers {
if group, ok := account.Groups[groupPeer.GroupID]; ok {
group.Peers = append(group.Peers, groupPeer.PeerID)
} else {
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
}
}
// var groupPeers []types.GroupPeer
// s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
// Find(&groupPeers)
// for _, groupPeer := range groupPeers {
// if group, ok := account.Groups[groupPeer.GroupID]; ok {
// group.Peers = append(group.Peers, groupPeer.PeerID)
// } else {
// log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
// }
// }
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
account.Routes[route.ID] = &route
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
ns.AccountID = ""
account.NameServerGroups[ns.ID] = &ns
}
account.NameServerGroupsG = nil
@@ -1199,8 +1777,42 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
if err != nil {
return nil, err
}
pool, err := connectDB(context.Background(), dsn)
if err != nil {
return nil, err
}
store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
if err != nil {
pool.Close()
return nil, err
}
store.pool = pool
return store, nil
}
return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
func connectDB(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("unable to parse database config: %w", err)
}
config.MaxConns = 10
config.MinConns = 2
config.MaxConnLifetime = time.Hour
config.HealthCheckPeriod = time.Minute
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("unable to create connection pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("unable to ping database: %w", err)
}
fmt.Println("Successfully connected to the database!")
return pool, nil
}
// NewMysqlStore creates a new MySQL store.

View File

@@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any {
func (r *Route) Copy() *Route {
route := &Route{
ID: r.ID,
AccountID: r.AccountID,
Description: r.Description,
NetID: r.NetID,
Network: r.Network,