mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
code cleanup
This commit is contained in:
@@ -2,7 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -16,7 +15,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
@@ -778,852 +776,6 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
var account types.Account
|
||||
account.Network = &types.Network{}
|
||||
const accountQuery = `
|
||||
SELECT
|
||||
id, created_by, created_at, domain, domain_category, is_domain_primary_account,
|
||||
-- Embedded Network
|
||||
network_identifier, network_net, network_dns, network_serial,
|
||||
-- Embedded DNSSettings
|
||||
dns_settings_disabled_management_groups,
|
||||
-- Embedded Settings
|
||||
settings_peer_login_expiration_enabled, settings_peer_login_expiration,
|
||||
settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration,
|
||||
settings_regular_users_view_blocked, settings_groups_propagation_enabled,
|
||||
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
|
||||
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
|
||||
settings_lazy_connection_enabled,
|
||||
-- Embedded ExtraSettings
|
||||
settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
|
||||
settings_extra_integrated_validator, settings_extra_integrated_validator_groups
|
||||
FROM accounts WHERE id = $1`
|
||||
|
||||
var networkNet, dnsSettingsDisabledGroups []byte
|
||||
var (
|
||||
sPeerLoginExpirationEnabled sql.NullBool
|
||||
sPeerLoginExpiration sql.NullInt64
|
||||
sPeerInactivityExpirationEnabled sql.NullBool
|
||||
sPeerInactivityExpiration sql.NullInt64
|
||||
sRegularUsersViewBlocked sql.NullBool
|
||||
sGroupsPropagationEnabled sql.NullBool
|
||||
sJWTGroupsEnabled sql.NullBool
|
||||
sJWTGroupsClaimName sql.NullString
|
||||
sJWTAllowGroups []byte
|
||||
sRoutingPeerDNSResolutionEnabled sql.NullBool
|
||||
sDNSDomain sql.NullString
|
||||
sNetworkRange []byte
|
||||
sLazyConnectionEnabled sql.NullBool
|
||||
sExtraPeerApprovalEnabled sql.NullBool
|
||||
sExtraUserApprovalRequired sql.NullBool
|
||||
sExtraIntegratedValidator sql.NullString
|
||||
sExtraIntegratedValidatorGroups []byte
|
||||
)
|
||||
|
||||
err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan(
|
||||
&account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount,
|
||||
&account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial,
|
||||
&dnsSettingsDisabledGroups,
|
||||
&sPeerLoginExpirationEnabled, &sPeerLoginExpiration,
|
||||
&sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration,
|
||||
&sRegularUsersViewBlocked, &sGroupsPropagationEnabled,
|
||||
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
|
||||
&sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
|
||||
&sLazyConnectionEnabled,
|
||||
&sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
|
||||
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = json.Unmarshal(networkNet, &account.Network.Net)
|
||||
_ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups)
|
||||
|
||||
account.Settings = &types.Settings{Extra: &types.ExtraSettings{}}
|
||||
if sPeerLoginExpirationEnabled.Valid {
|
||||
account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool
|
||||
}
|
||||
if sPeerLoginExpiration.Valid {
|
||||
account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64)
|
||||
}
|
||||
if sPeerInactivityExpirationEnabled.Valid {
|
||||
account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool
|
||||
}
|
||||
if sPeerInactivityExpiration.Valid {
|
||||
account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64)
|
||||
}
|
||||
if sRegularUsersViewBlocked.Valid {
|
||||
account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool
|
||||
}
|
||||
if sGroupsPropagationEnabled.Valid {
|
||||
account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool
|
||||
}
|
||||
if sJWTGroupsEnabled.Valid {
|
||||
account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool
|
||||
}
|
||||
if sJWTGroupsClaimName.Valid {
|
||||
account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String
|
||||
}
|
||||
if sRoutingPeerDNSResolutionEnabled.Valid {
|
||||
account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool
|
||||
}
|
||||
if sDNSDomain.Valid {
|
||||
account.Settings.DNSDomain = sDNSDomain.String
|
||||
}
|
||||
if sLazyConnectionEnabled.Valid {
|
||||
account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool
|
||||
}
|
||||
if sJWTAllowGroups != nil {
|
||||
_ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups)
|
||||
}
|
||||
if sNetworkRange != nil {
|
||||
_ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange)
|
||||
}
|
||||
|
||||
if sExtraPeerApprovalEnabled.Valid {
|
||||
account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool
|
||||
}
|
||||
if sExtraUserApprovalRequired.Valid {
|
||||
account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool
|
||||
}
|
||||
if sExtraIntegratedValidator.Valid {
|
||||
account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String
|
||||
}
|
||||
if sExtraIntegratedValidatorGroups != nil {
|
||||
_ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) {
|
||||
var sk types.SetupKey
|
||||
var autoGroups []byte
|
||||
var expiresAt, updatedAt, lastUsed sql.NullTime
|
||||
var revoked, ephemeral, allowExtraDNSLabels sql.NullBool
|
||||
var usedTimes, usageLimit sql.NullInt64
|
||||
|
||||
err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels)
|
||||
|
||||
if err == nil {
|
||||
if expiresAt.Valid {
|
||||
sk.ExpiresAt = &expiresAt.Time
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
sk.UpdatedAt = updatedAt.Time
|
||||
if sk.UpdatedAt.IsZero() {
|
||||
sk.UpdatedAt = sk.CreatedAt
|
||||
}
|
||||
}
|
||||
if lastUsed.Valid {
|
||||
sk.LastUsed = &lastUsed.Time
|
||||
}
|
||||
if revoked.Valid {
|
||||
sk.Revoked = revoked.Bool
|
||||
}
|
||||
if usedTimes.Valid {
|
||||
sk.UsedTimes = int(usedTimes.Int64)
|
||||
}
|
||||
if usageLimit.Valid {
|
||||
sk.UsageLimit = int(usageLimit.Int64)
|
||||
}
|
||||
if ephemeral.Valid {
|
||||
sk.Ephemeral = ephemeral.Bool
|
||||
}
|
||||
if allowExtraDNSLabels.Valid {
|
||||
sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
|
||||
}
|
||||
if autoGroups != nil {
|
||||
_ = json.Unmarshal(autoGroups, &sk.AutoGroups)
|
||||
} else {
|
||||
sk.AutoGroups = []string{}
|
||||
}
|
||||
}
|
||||
return sk, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.SetupKeysG = keys
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) {
|
||||
var p nbpeer.Peer
|
||||
p.Status = &nbpeer.PeerStatus{}
|
||||
var lastLogin sql.NullTime
|
||||
var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
||||
var peerStatusLastSeen sql.NullTime
|
||||
var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool
|
||||
var ip, extraDNS, netAddr, env, flags, files, connIP []byte
|
||||
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID)
|
||||
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
p.LastLogin = &lastLogin.Time
|
||||
}
|
||||
if sshEnabled.Valid {
|
||||
p.SSHEnabled = sshEnabled.Bool
|
||||
}
|
||||
if loginExpirationEnabled.Valid {
|
||||
p.LoginExpirationEnabled = loginExpirationEnabled.Bool
|
||||
}
|
||||
if inactivityExpirationEnabled.Valid {
|
||||
p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool
|
||||
}
|
||||
if ephemeral.Valid {
|
||||
p.Ephemeral = ephemeral.Bool
|
||||
}
|
||||
if allowExtraDNSLabels.Valid {
|
||||
p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
|
||||
}
|
||||
if peerStatusLastSeen.Valid {
|
||||
p.Status.LastSeen = peerStatusLastSeen.Time
|
||||
}
|
||||
if peerStatusConnected.Valid {
|
||||
p.Status.Connected = peerStatusConnected.Bool
|
||||
}
|
||||
if peerStatusLoginExpired.Valid {
|
||||
p.Status.LoginExpired = peerStatusLoginExpired.Bool
|
||||
}
|
||||
if peerStatusRequiresApproval.Valid {
|
||||
p.Status.RequiresApproval = peerStatusRequiresApproval.Bool
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
_ = json.Unmarshal(ip, &p.IP)
|
||||
}
|
||||
if extraDNS != nil {
|
||||
_ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels)
|
||||
}
|
||||
if netAddr != nil {
|
||||
_ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses)
|
||||
}
|
||||
if env != nil {
|
||||
_ = json.Unmarshal(env, &p.Meta.Environment)
|
||||
}
|
||||
if flags != nil {
|
||||
_ = json.Unmarshal(flags, &p.Meta.Flags)
|
||||
}
|
||||
if files != nil {
|
||||
_ = json.Unmarshal(files, &p.Meta.Files)
|
||||
}
|
||||
if connIP != nil {
|
||||
_ = json.Unmarshal(connIP, &p.Location.ConnectionIP)
|
||||
}
|
||||
}
|
||||
return p, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.PeersG = peers
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
||||
var u types.User
|
||||
var autoGroups []byte
|
||||
var lastLogin sql.NullTime
|
||||
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = &lastLogin.Time
|
||||
}
|
||||
if isServiceUser.Valid {
|
||||
u.IsServiceUser = isServiceUser.Bool
|
||||
}
|
||||
if nonDeletable.Valid {
|
||||
u.NonDeletable = nonDeletable.Bool
|
||||
}
|
||||
if blocked.Valid {
|
||||
u.Blocked = blocked.Bool
|
||||
}
|
||||
if pendingApproval.Valid {
|
||||
u.PendingApproval = pendingApproval.Bool
|
||||
}
|
||||
if autoGroups != nil {
|
||||
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
|
||||
} else {
|
||||
u.AutoGroups = []string{}
|
||||
}
|
||||
}
|
||||
return u, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.UsersG = users
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) {
|
||||
var g types.Group
|
||||
var resources []byte
|
||||
var refID sql.NullInt64
|
||||
var refType sql.NullString
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
if err == nil {
|
||||
if refID.Valid {
|
||||
g.IntegrationReference.ID = int(refID.Int64)
|
||||
}
|
||||
if refType.Valid {
|
||||
g.IntegrationReference.IntegrationType = refType.String
|
||||
}
|
||||
if resources != nil {
|
||||
_ = json.Unmarshal(resources, &g.Resources)
|
||||
} else {
|
||||
g.Resources = []types.Resource{}
|
||||
}
|
||||
g.GroupPeers = []types.GroupPeer{}
|
||||
g.Peers = []string{}
|
||||
}
|
||||
return &g, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.GroupsG = groups
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) {
|
||||
var p types.Policy
|
||||
var checks []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
p.Enabled = enabled.Bool
|
||||
}
|
||||
if checks != nil {
|
||||
_ = json.Unmarshal(checks, &p.SourcePostureChecks)
|
||||
}
|
||||
}
|
||||
return &p, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Policies = policies
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) {
|
||||
var r route.Route
|
||||
var network, domains, peerGroups, groups, accessGroups []byte
|
||||
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
if err == nil {
|
||||
if keepRoute.Valid {
|
||||
r.KeepRoute = keepRoute.Bool
|
||||
}
|
||||
if masquerade.Valid {
|
||||
r.Masquerade = masquerade.Bool
|
||||
}
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
}
|
||||
if skipAutoApply.Valid {
|
||||
r.SkipAutoApply = skipAutoApply.Bool
|
||||
}
|
||||
if metric.Valid {
|
||||
r.Metric = int(metric.Int64)
|
||||
}
|
||||
if network != nil {
|
||||
_ = json.Unmarshal(network, &r.Network)
|
||||
}
|
||||
if domains != nil {
|
||||
_ = json.Unmarshal(domains, &r.Domains)
|
||||
}
|
||||
if peerGroups != nil {
|
||||
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
|
||||
}
|
||||
if groups != nil {
|
||||
_ = json.Unmarshal(groups, &r.Groups)
|
||||
}
|
||||
if accessGroups != nil {
|
||||
_ = json.Unmarshal(accessGroups, &r.AccessControlGroups)
|
||||
}
|
||||
}
|
||||
return r, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.RoutesG = routes
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) {
|
||||
var n nbdns.NameServerGroup
|
||||
var ns, groups, domains []byte
|
||||
var primary, enabled, searchDomainsEnabled sql.NullBool
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
if err == nil {
|
||||
if primary.Valid {
|
||||
n.Primary = primary.Bool
|
||||
}
|
||||
if enabled.Valid {
|
||||
n.Enabled = enabled.Bool
|
||||
}
|
||||
if searchDomainsEnabled.Valid {
|
||||
n.SearchDomainsEnabled = searchDomainsEnabled.Bool
|
||||
}
|
||||
if ns != nil {
|
||||
_ = json.Unmarshal(ns, &n.NameServers)
|
||||
} else {
|
||||
n.NameServers = []nbdns.NameServer{}
|
||||
}
|
||||
if groups != nil {
|
||||
_ = json.Unmarshal(groups, &n.Groups)
|
||||
} else {
|
||||
n.Groups = []string{}
|
||||
}
|
||||
if domains != nil {
|
||||
_ = json.Unmarshal(domains, &n.Domains)
|
||||
} else {
|
||||
n.Domains = []string{}
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.NameServerGroupsG = nsgs
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
|
||||
var c posture.Checks
|
||||
var checksDef []byte
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
|
||||
if err == nil && checksDef != nil {
|
||||
_ = json.Unmarshal(checksDef, &c.Checks)
|
||||
}
|
||||
return &c, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.PostureChecks = checks
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network])
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Networks = make([]*networkTypes.Network, len(networks))
|
||||
for i := range networks {
|
||||
account.Networks[i] = &networks[i]
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) {
|
||||
var r routerTypes.NetworkRouter
|
||||
var peerGroups []byte
|
||||
var masquerade, enabled sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
if err == nil {
|
||||
if masquerade.Valid {
|
||||
r.Masquerade = masquerade.Bool
|
||||
}
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
}
|
||||
if metric.Valid {
|
||||
r.Metric = int(metric.Int64)
|
||||
}
|
||||
if peerGroups != nil {
|
||||
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
|
||||
}
|
||||
}
|
||||
return r, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers))
|
||||
for i := range routers {
|
||||
account.NetworkRouters[i] = &routers[i]
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) {
|
||||
var r resourceTypes.NetworkResource
|
||||
var prefix []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
}
|
||||
if prefix != nil {
|
||||
_ = json.Unmarshal(prefix, &r.Prefix)
|
||||
}
|
||||
}
|
||||
return r, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources))
|
||||
for i := range resources {
|
||||
account.NetworkResources[i] = &resources[i]
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
|
||||
var onboardingFlowPending, signupFormPending sql.NullBool
|
||||
err := s.pool.QueryRow(ctx, query, accountID).Scan(
|
||||
&account.Onboarding.AccountID,
|
||||
&onboardingFlowPending,
|
||||
&signupFormPending,
|
||||
&account.Onboarding.CreatedAt,
|
||||
&account.Onboarding.UpdatedAt,
|
||||
)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
if onboardingFlowPending.Valid {
|
||||
account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool
|
||||
}
|
||||
if signupFormPending.Valid {
|
||||
account.Onboarding.SignupFormPending = signupFormPending.Bool
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for e := range errChan {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
}
|
||||
|
||||
var userIDs []string
|
||||
for _, u := range account.UsersG {
|
||||
userIDs = append(userIDs, u.Id)
|
||||
}
|
||||
var policyIDs []string
|
||||
for _, p := range account.Policies {
|
||||
policyIDs = append(policyIDs, p.ID)
|
||||
}
|
||||
var groupIDs []string
|
||||
for _, g := range account.GroupsG {
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
errChan = make(chan error, 3)
|
||||
|
||||
var pats []types.PersonalAccessToken
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if len(userIDs) == 0 {
|
||||
return
|
||||
}
|
||||
const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, userIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
var expirationDate, lastUsed sql.NullTime
|
||||
err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed)
|
||||
if err == nil {
|
||||
if expirationDate.Valid {
|
||||
pat.ExpirationDate = &expirationDate.Time
|
||||
}
|
||||
if lastUsed.Valid {
|
||||
pat.LastUsed = &lastUsed.Time
|
||||
}
|
||||
}
|
||||
return pat, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
var rules []*types.PolicyRule
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if len(policyIDs) == 0 {
|
||||
return
|
||||
}
|
||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, policyIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
||||
var r types.PolicyRule
|
||||
var dest, destRes, sources, sourceRes, ports, portRanges []byte
|
||||
var enabled, bidirectional sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
}
|
||||
if bidirectional.Valid {
|
||||
r.Bidirectional = bidirectional.Bool
|
||||
}
|
||||
if dest != nil {
|
||||
_ = json.Unmarshal(dest, &r.Destinations)
|
||||
}
|
||||
if destRes != nil {
|
||||
_ = json.Unmarshal(destRes, &r.DestinationResource)
|
||||
}
|
||||
if sources != nil {
|
||||
_ = json.Unmarshal(sources, &r.Sources)
|
||||
}
|
||||
if sourceRes != nil {
|
||||
_ = json.Unmarshal(sourceRes, &r.SourceResource)
|
||||
}
|
||||
if ports != nil {
|
||||
_ = json.Unmarshal(ports, &r.Ports)
|
||||
}
|
||||
if portRanges != nil {
|
||||
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
||||
}
|
||||
}
|
||||
return &r, err
|
||||
})
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
var groupPeers []types.GroupPeer
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if len(groupIDs) == 0 {
|
||||
return
|
||||
}
|
||||
const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, groupIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for e := range errChan {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
}
|
||||
|
||||
patsByUserID := make(map[string][]*types.PersonalAccessToken)
|
||||
for i := range pats {
|
||||
pat := &pats[i]
|
||||
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
|
||||
pat.UserID = ""
|
||||
}
|
||||
|
||||
rulesByPolicyID := make(map[string][]*types.PolicyRule)
|
||||
for _, rule := range rules {
|
||||
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
|
||||
}
|
||||
|
||||
peersByGroupID := make(map[string][]string)
|
||||
for _, gp := range groupPeers {
|
||||
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for i := range account.SetupKeysG {
|
||||
key := &account.SetupKeysG[i]
|
||||
account.SetupKeys[key.Key] = key
|
||||
}
|
||||
|
||||
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
||||
for i := range account.PeersG {
|
||||
peer := &account.PeersG[i]
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for i := range account.UsersG {
|
||||
user := &account.UsersG[i]
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken)
|
||||
if userPats, ok := patsByUserID[user.Id]; ok {
|
||||
for j := range userPats {
|
||||
pat := userPats[j]
|
||||
user.PATs[pat.ID] = pat
|
||||
}
|
||||
}
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
for i := range account.Policies {
|
||||
policy := account.Policies[i]
|
||||
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
|
||||
policy.Rules = policyRules
|
||||
}
|
||||
}
|
||||
|
||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||
for i := range account.GroupsG {
|
||||
group := account.GroupsG[i]
|
||||
if peerIDs, ok := peersByGroupID[group.ID]; ok {
|
||||
group.Peers = peerIDs
|
||||
}
|
||||
account.Groups[group.ID] = group
|
||||
}
|
||||
|
||||
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||
for i := range account.RoutesG {
|
||||
route := &account.RoutesG[i]
|
||||
account.Routes[route.ID] = route
|
||||
}
|
||||
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
||||
for i := range account.NameServerGroupsG {
|
||||
nsg := &account.NameServerGroupsG[i]
|
||||
nsg.AccountID = ""
|
||||
account.NameServerGroups[nsg.ID] = nsg
|
||||
}
|
||||
|
||||
account.SetupKeysG = nil
|
||||
account.PeersG = nil
|
||||
account.UsersG = nil
|
||||
account.GroupsG = nil
|
||||
account.RoutesG = nil
|
||||
account.NameServerGroupsG = nil
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
@@ -1634,8 +786,7 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.
|
||||
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
// Omit("GroupsG").
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
Preload("UsersG.PATsG").
|
||||
Preload("Policies.Rules").
|
||||
Preload("SetupKeysG").
|
||||
Preload("PeersG").
|
||||
@@ -1657,21 +808,14 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
||||
// for i, policy := range account.Policies {
|
||||
// var rules []*types.PolicyRule
|
||||
// err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
// if err != nil {
|
||||
// return nil, status.Errorf(status.NotFound, "rule not found")
|
||||
// }
|
||||
// account.Policies[i].Rules = rules
|
||||
// }
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for _, key := range account.SetupKeysG {
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = key.CreatedAt
|
||||
}
|
||||
if key.AutoGroups == nil {
|
||||
key.AutoGroups = []string{}
|
||||
}
|
||||
account.SetupKeys[key.Key] = &key
|
||||
}
|
||||
account.SetupKeysG = nil
|
||||
@@ -1689,6 +833,9 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.
|
||||
pat.UserID = ""
|
||||
user.PATs[pat.ID] = &pat
|
||||
}
|
||||
if user.AutoGroups == nil {
|
||||
user.AutoGroups = []string{}
|
||||
}
|
||||
account.Users[user.Id] = &user
|
||||
user.PATsG = nil
|
||||
}
|
||||
@@ -1700,21 +847,13 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.
|
||||
for i, gp := range group.GroupPeers {
|
||||
group.Peers[i] = gp.PeerID
|
||||
}
|
||||
if group.Resources == nil {
|
||||
group.Resources = []types.Resource{}
|
||||
}
|
||||
account.Groups[group.ID] = group
|
||||
}
|
||||
account.GroupsG = nil
|
||||
|
||||
// var groupPeers []types.GroupPeer
|
||||
// s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
|
||||
// Find(&groupPeers)
|
||||
// for _, groupPeer := range groupPeers {
|
||||
// if group, ok := account.Groups[groupPeer.GroupID]; ok {
|
||||
// group.Peers = append(group.Peers, groupPeer.PeerID)
|
||||
// } else {
|
||||
// log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
|
||||
// }
|
||||
// }
|
||||
|
||||
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||
for _, route := range account.RoutesG {
|
||||
account.Routes[route.ID] = &route
|
||||
@@ -1724,6 +863,15 @@ func (s *SqlStore) getAccountOld(ctx context.Context, accountID string) (*types.
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
||||
for _, ns := range account.NameServerGroupsG {
|
||||
ns.AccountID = ""
|
||||
if ns.NameServers == nil {
|
||||
ns.NameServers = []nbdns.NameServer{}
|
||||
}
|
||||
if ns.Groups == nil {
|
||||
ns.Groups = []string{}
|
||||
}
|
||||
if ns.Domains == nil {
|
||||
ns.Domains = []string{}
|
||||
}
|
||||
account.NameServerGroups[ns.ID] = &ns
|
||||
}
|
||||
account.NameServerGroupsG = nil
|
||||
@@ -2070,7 +1218,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pool, err := connectDB(context.Background(), dsn)
|
||||
pool, err := connectToPgDb(context.Background(), dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2083,7 +1231,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func connectDB(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
||||
func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse database config: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user