Files
netbird/management/server/store/sqlstore_bench_test.go
2025-10-19 17:03:36 +02:00

1494 lines
48 KiB
Go

package store
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"os"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	nbdns "github.com/netbirdio/netbird/dns"
	resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
	routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
	networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
	nbpeer "github.com/netbirdio/netbird/management/server/peer"
	"github.com/netbirdio/netbird/management/server/posture"
	"github.com/netbirdio/netbird/management/server/types"
	"github.com/netbirdio/netbird/route"
	"github.com/netbirdio/netbird/shared/management/domain"
	"github.com/netbirdio/netbird/shared/management/status"
)
// GetAccountSlow loads an account and all of its associations through GORM
// preloading and converts the GORM slice fields (SetupKeysG, PeersG, UsersG,
// GroupsG, RoutesG, NameServerGroupsG) into the account's ID-keyed maps.
// It is the reference ("old") implementation used by the benchmarks and
// equivalence tests in this file.
//
// It returns status.NewAccountNotFoundError when no account with accountID
// exists and status.NewGetAccountFromStoreError for any other store failure.
// Loads slower than one second are logged at trace level.
func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) {
	start := time.Now()
	defer func() {
		elapsed := time.Since(start)
		if elapsed > 1*time.Second {
			log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
		}
	}()

	var account types.Account
	result := s.db.Model(&account).
		Omit("GroupsG").
		Preload("UsersG.PATsG"). // nested association: clause.Associations does not preload nested levels
		Preload(clause.Associations).
		Take(&account, idQueryCondition, accountID)
	if result.Error != nil {
		log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return nil, status.NewAccountNotFoundError(accountID)
		}
		return nil, status.NewGetAccountFromStoreError(result.Error)
	}

	// GORM preloading does not populate policy rules here, so load them per policy.
	for i, policy := range account.Policies {
		var rules []*types.PolicyRule
		// Find does not fail on an empty result set, so an error here is a
		// genuine store failure, not a missing rule.
		if err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error; err != nil {
			return nil, status.NewGetAccountFromStoreError(err)
		}
		account.Policies[i].Rules = rules
	}

	account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
	for _, key := range account.SetupKeysG {
		account.SetupKeys[key.Key] = key.Copy()
	}
	account.SetupKeysG = nil

	account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
	for _, peer := range account.PeersG {
		account.Peers[peer.ID] = peer.Copy()
	}
	account.PeersG = nil

	account.Users = make(map[string]*types.User, len(account.UsersG))
	for _, user := range account.UsersG {
		// Size from PATsG (the preloaded source), not from the still-empty PATs map.
		user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
		for _, pat := range user.PATsG {
			user.PATs[pat.ID] = pat.Copy()
		}
		account.Users[user.Id] = user.Copy()
	}
	account.UsersG = nil

	account.Groups = make(map[string]*types.Group, len(account.GroupsG))
	for _, group := range account.GroupsG {
		account.Groups[group.ID] = group.Copy()
	}
	account.GroupsG = nil

	// Group membership is read from the GroupPeer model and attached to each group.
	var groupPeers []types.GroupPeer
	if err := s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).Find(&groupPeers).Error; err != nil {
		return nil, status.NewGetAccountFromStoreError(err)
	}
	for _, groupPeer := range groupPeers {
		if group, ok := account.Groups[groupPeer.GroupID]; ok {
			group.Peers = append(group.Peers, groupPeer.PeerID)
		} else {
			log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
		}
	}

	account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
	for _, r := range account.RoutesG {
		account.Routes[r.ID] = r.Copy()
	}
	account.RoutesG = nil

	account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
	for _, ns := range account.NameServerGroupsG {
		account.NameServerGroups[ns.ID] = ns.Copy()
	}
	account.NameServerGroupsG = nil

	return &account, nil
}
// connectDBforTest opens a pgx connection pool for dsn with a fixed pool
// configuration (2 to 10 connections, one-hour connection lifetime, one-minute
// health-check period) and verifies connectivity with a ping. If the ping
// fails, the pool is closed before returning the error. On success it prints a
// confirmation line to stdout. The caller owns the returned pool and must
// close it.
//
// NOTE(review): not referenced in the visible part of this file;
// setupBenchmarkDB calls connectDB instead.
func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
	cfg, parseErr := pgxpool.ParseConfig(dsn)
	if parseErr != nil {
		return nil, fmt.Errorf("unable to parse database config: %w", parseErr)
	}

	cfg.MinConns, cfg.MaxConns = 2, 10
	cfg.HealthCheckPeriod = time.Minute
	cfg.MaxConnLifetime = time.Hour

	p, createErr := pgxpool.NewWithConfig(ctx, cfg)
	if createErr != nil {
		return nil, fmt.Errorf("unable to create connection pool: %w", createErr)
	}

	if pingErr := p.Ping(ctx); pingErr != nil {
		p.Close()
		return nil, fmt.Errorf("unable to ping database: %w", pingErr)
	}

	fmt.Println("Successfully connected to the database!")
	return p, nil
}
// setupBenchmarkDB connects to a Postgres database, drops and re-migrates every
// store table, and seeds one account ("benchmark-account-id"). The account gets
// a fixed number of setup keys, peers, users with PATs, groups, policies with
// rules, routes, nameserver groups, posture checks, networks, network routers
// and network resources, plus an onboarding record.
//
// The DSN defaults to a local development database and can be overridden with
// the NETBIRD_BENCH_POSTGRES_DSN environment variable. WARNING: all store
// tables are dropped, so never point it at a database holding real data.
//
// The GORM connection and the pgx pool are closed through b.Cleanup when the
// test or benchmark finishes. Any failure aborts via Fatalf. It returns the
// store and the ID of the seeded account.
func setupBenchmarkDB(b testing.TB) (*SqlStore, string) {
	const (
		dsnEnvVar  = "NETBIRD_BENCH_POSTGRES_DSN"
		defaultDSN = "host=localhost user=postgres password=mysecretpassword dbname=testdb port=5432 sslmode=disable TimeZone=Europe/Berlin"
	)
	dsn := defaultDSN
	if v := os.Getenv(dsnEnvVar); v != "" {
		dsn = v
	}

	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		b.Fatalf("failed to connect database: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		b.Fatalf("failed to get database handle: %v", err)
	}
	b.Cleanup(func() { _ = sqlDB.Close() })

	pool, err := connectDB(context.Background(), dsn)
	if err != nil {
		b.Fatalf("failed to connect database: %v", err)
	}
	// pgxpool.Pool.Close is safe to call more than once, so callers may still
	// close store.pool explicitly.
	b.Cleanup(pool.Close)

	models := []any{
		&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
		&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
		&types.Policy{}, &types.PolicyRule{}, &route.Route{},
		&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
		&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
		&types.AccountOnboarding{},
	}
	// Drop in reverse order of the models list so every run starts from an empty schema.
	for i := len(models) - 1; i >= 0; i-- {
		if err := db.Migrator().DropTable(models[i]); err != nil {
			b.Fatalf("failed to drop table for %T: %v", models[i], err)
		}
	}
	if err := db.AutoMigrate(models...); err != nil {
		b.Fatalf("failed to migrate database: %v", err)
	}

	store := &SqlStore{
		db:   db,
		pool: pool,
	}

	const (
		accountID           = "benchmark-account-id"
		numUsers            = 20
		numPatsPerUser      = 3
		numSetupKeys        = 25
		numPeers            = 200
		numGroups           = 30
		numPolicies         = 50
		numRulesPerPolicy   = 10
		numRoutes           = 40
		numNSGroups         = 10
		numPostureChecks    = 15
		numNetworks         = 5
		numNetworkRouters   = 5
		numNetworkResources = 10
	)

	_, ipNet, _ := net.ParseCIDR("100.64.0.0/10")
	acc := types.Account{
		Id:                     accountID,
		CreatedBy:              "benchmark-user",
		CreatedAt:              time.Now(),
		Domain:                 "benchmark.com",
		IsDomainPrimaryAccount: true,
		Network: &types.Network{
			Identifier: "benchmark-net",
			Net:        *ipNet,
			Serial:     1,
		},
		DNSSettings: types.DNSSettings{
			DisabledManagementGroups: []string{"group-disabled-1"},
		},
		Settings: &types.Settings{},
	}
	if err := db.Create(&acc).Error; err != nil {
		b.Fatalf("create account: %v", err)
	}

	setupKeys := make([]types.SetupKey, 0, numSetupKeys)
	for i := 0; i < numSetupKeys; i++ {
		setupKeys = append(setupKeys, types.SetupKey{
			Id:        fmt.Sprintf("keyid-%d", i),
			AccountID: accountID,
			Key:       fmt.Sprintf("key-%d", i),
			Name:      fmt.Sprintf("Benchmark Key %d", i),
			ExpiresAt: &time.Time{},
		})
	}
	if err := db.Create(&setupKeys).Error; err != nil {
		b.Fatalf("create setup keys: %v", err)
	}

	peers := make([]nbpeer.Peer, 0, numPeers)
	for i := 0; i < numPeers; i++ {
		peers = append(peers, nbpeer.Peer{
			ID:        fmt.Sprintf("peer-%d", i),
			AccountID: accountID,
			Key:       fmt.Sprintf("peerkey-%d", i),
			IP:        net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)),
			Name:      fmt.Sprintf("peer-name-%d", i),
			Status:    &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()},
		})
	}
	if err := db.Create(&peers).Error; err != nil {
		b.Fatalf("create peers: %v", err)
	}

	for i := 0; i < numUsers; i++ {
		userID := fmt.Sprintf("user-%d", i)
		user := types.User{Id: userID, AccountID: accountID}
		if err := db.Create(&user).Error; err != nil {
			b.Fatalf("create user %s: %v", userID, err)
		}
		pats := make([]types.PersonalAccessToken, 0, numPatsPerUser)
		for j := 0; j < numPatsPerUser; j++ {
			pats = append(pats, types.PersonalAccessToken{
				ID:     fmt.Sprintf("pat-%d-%d", i, j),
				UserID: userID,
				Name:   fmt.Sprintf("PAT %d for User %d", j, i),
			})
		}
		if err := db.Create(&pats).Error; err != nil {
			b.Fatalf("create pats for user %s: %v", userID, err)
		}
	}

	groups := make([]*types.Group, 0, numGroups)
	for i := 0; i < numGroups; i++ {
		groups = append(groups, &types.Group{
			ID:        fmt.Sprintf("group-%d", i),
			AccountID: accountID,
			Name:      fmt.Sprintf("Group %d", i),
		})
	}
	if err := db.Create(&groups).Error; err != nil {
		b.Fatalf("create groups: %v", err)
	}

	for i := 0; i < numPolicies; i++ {
		policyID := fmt.Sprintf("policy-%d", i)
		policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true}
		if err := db.Create(&policy).Error; err != nil {
			b.Fatalf("create policy %s: %v", policyID, err)
		}
		rules := make([]*types.PolicyRule, 0, numRulesPerPolicy)
		for j := 0; j < numRulesPerPolicy; j++ {
			rules = append(rules, &types.PolicyRule{
				ID:       fmt.Sprintf("rule-%d-%d", i, j),
				PolicyID: policyID,
				Name:     fmt.Sprintf("Rule %d for Policy %d", j, i),
				Enabled:  true,
				Protocol: "all",
			})
		}
		if err := db.Create(&rules).Error; err != nil {
			b.Fatalf("create rules for policy %s: %v", policyID, err)
		}
	}

	routes := make([]route.Route, 0, numRoutes)
	for i := 0; i < numRoutes; i++ {
		routes = append(routes, route.Route{
			ID:          route.ID(fmt.Sprintf("route-%d", i)),
			AccountID:   accountID,
			Description: fmt.Sprintf("Route %d", i),
			Network:     netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)),
			Enabled:     true,
		})
	}
	if err := db.Create(&routes).Error; err != nil {
		b.Fatalf("create routes: %v", err)
	}

	nsGroups := make([]nbdns.NameServerGroup, 0, numNSGroups)
	for i := 0; i < numNSGroups; i++ {
		nsGroups = append(nsGroups, nbdns.NameServerGroup{
			ID:          fmt.Sprintf("nsg-%d", i),
			AccountID:   accountID,
			Name:        fmt.Sprintf("NS Group %d", i),
			Description: "Benchmark NS Group",
			Enabled:     true,
		})
	}
	if err := db.Create(&nsGroups).Error; err != nil {
		b.Fatalf("create nsgroups: %v", err)
	}

	postureChecks := make([]*posture.Checks, 0, numPostureChecks)
	for i := 0; i < numPostureChecks; i++ {
		postureChecks = append(postureChecks, &posture.Checks{
			ID:        fmt.Sprintf("pc-%d", i),
			AccountID: accountID,
			Name:      fmt.Sprintf("Posture Check %d", i),
		})
	}
	if err := db.Create(&postureChecks).Error; err != nil {
		b.Fatalf("create posture checks: %v", err)
	}

	networks := make([]*networkTypes.Network, 0, numNetworks)
	for i := 0; i < numNetworks; i++ {
		networks = append(networks, &networkTypes.Network{
			ID:        fmt.Sprintf("nettype-%d", i),
			AccountID: accountID,
			Name:      fmt.Sprintf("Network Type %d", i),
		})
	}
	if err := db.Create(&networks).Error; err != nil {
		b.Fatalf("create networks: %v", err)
	}

	networkRouters := make([]*routerTypes.NetworkRouter, 0, numNetworkRouters)
	for i := 0; i < numNetworkRouters; i++ {
		networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
			ID:        fmt.Sprintf("router-%d", i),
			AccountID: accountID,
			NetworkID: networks[i%numNetworks].ID,
			Peer:      peers[i%numPeers].ID,
		})
	}
	if err := db.Create(&networkRouters).Error; err != nil {
		b.Fatalf("create network routers: %v", err)
	}

	networkResources := make([]*resourceTypes.NetworkResource, 0, numNetworkResources)
	for i := 0; i < numNetworkResources; i++ {
		networkResources = append(networkResources, &resourceTypes.NetworkResource{
			ID:        fmt.Sprintf("resource-%d", i),
			AccountID: accountID,
			NetworkID: networks[i%numNetworks].ID,
			Name:      fmt.Sprintf("Resource %d", i),
		})
	}
	if err := db.Create(&networkResources).Error; err != nil {
		b.Fatalf("create network resources: %v", err)
	}

	onboarding := types.AccountOnboarding{
		AccountID:             accountID,
		OnboardingFlowPending: true,
	}
	if err := db.Create(&onboarding).Error; err != nil {
		b.Fatalf("create onboarding: %v", err)
	}

	return store, accountID
}
// BenchmarkGetAccount compares the account loaders against the same seeded
// Postgres database:
//   - old: GetAccountSlow (GORM preloading)
//   - new: GetAccount
//   - raw: GetAccountPureSQL (hand-written pgx queries)
func BenchmarkGetAccount(b *testing.B) {
	store, accountID := setupBenchmarkDB(b)
	defer store.pool.Close()
	ctx := context.Background()

	loaders := []struct {
		name   string
		method string // used in failure messages
		load   func(context.Context, string) (*types.Account, error)
	}{
		{"old", "GetAccountSlow", store.GetAccountSlow},
		{"new", "GetAccount", store.GetAccount},
		{"raw", "GetAccountPureSQL", store.GetAccountPureSQL},
	}

	for _, l := range loaders {
		b.Run(l.name, func(b *testing.B) {
			// ReportAllocs applies only to the benchmark it is called on, so it
			// must be enabled on each sub-benchmark rather than on the parent.
			b.ReportAllocs()
			for range b.N {
				if _, err := l.load(ctx, accountID); err != nil {
					b.Fatalf("%s failed: %v", l.method, err)
				}
			}
		})
	}
}
// TestAccountEquivalence checks that alternative account loaders return the
// same account as the GORM-based reference loader GetAccountSlow. It also
// smoke-checks that GetAccountSlow and GetAccount load the seeded account
// without error.
func TestAccountEquivalence(t *testing.T) {
	store, accountID := setupBenchmarkDB(t)
	ctx := context.Background()

	type getAccountFunc func(context.Context, string) (*types.Account, error)
	tests := []struct {
		name      string
		expectedF getAccountFunc
		actualF   getAccountFunc
	}{
		// {"old vs new", store.GetAccountSlow, store.GetAccount},
		{"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// require (not assert): testAccountEquivalence dereferences both
			// accounts, so continuing after a failed load would panic.
			expected, errOld := tt.expectedF(ctx, accountID)
			require.NoError(t, errOld, "expected function should not return an error")
			require.NotNil(t, expected, "expected should not be nil")

			actual, errNew := tt.actualF(ctx, accountID)
			require.NoError(t, errNew, "actual function should not return an error")
			require.NotNil(t, actual, "actual should not be nil")

			testAccountEquivalence(t, expected, actual)
		})
	}

	// Smoke check: nothing below dereferences the results, so non-fatal asserts suffice.
	expected, errOld := store.GetAccountSlow(ctx, accountID)
	assert.NoError(t, errOld, "GetAccountSlow should not return an error")
	assert.NotNil(t, expected, "expected should not be nil")

	actual, errNew := store.GetAccount(ctx, accountID)
	assert.NoError(t, errNew, "GetAccount (new) should not return an error")
	assert.NotNil(t, actual, "actual should not be nil")
}
// testAccountEquivalence asserts (non-fatally) that actual matches expected
// field by field.
//
// Details of the comparison:
//   - Maps are compared entry by entry, keyed by ID.
//   - Users are compared with their PAT maps checked separately.
//   - Group peer lists and the Policies, rule, PostureChecks, Networks,
//     NetworkRouters and NetworkResources slices are sorted by ID first.
//     This sorting mutates both accounts.
//   - Missing map entries and slice length mismatches are reported without
//     dereferencing or indexing past the end, so a mismatch cannot panic.
func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
	t.Helper()

	// Uncomment to treat nil and empty slices as equal:
	// normalizeNilSlices(expected)
	// normalizeNilSlices(actual)

	assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal")
	assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal")
	assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second")
	assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal")
	assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal")
	assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal")
	assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal")
	assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal")
	assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal")

	assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements")
	for key, oldVal := range expected.SetupKeys {
		newVal, ok := actual.SetupKeys[key]
		if !assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) {
			continue
		}
		assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key)
	}

	assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements")
	for key, oldVal := range expected.Peers {
		newVal, ok := actual.Peers[key]
		if !assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) {
			continue
		}
		assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key)
	}

	assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements")
	for key, oldUser := range expected.Users {
		newUser, ok := actual.Users[key]
		if !assert.True(t, ok, "User with ID '%s' should exist in new account", key) {
			continue
		}
		assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key)
		for patKey, oldPAT := range oldUser.PATs {
			newPAT, patOk := newUser.PATs[patKey]
			if !assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) {
				continue
			}
			assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key)
		}
		// Compare the remaining fields on shallow copies so the callers' users keep their PATs.
		oldCopy, newCopy := *oldUser, *newUser
		oldCopy.PATs, newCopy.PATs = nil, nil
		assert.Equal(t, oldCopy, newCopy, "User struct for ID '%s' (without PATs) should be equal", key)
	}

	assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
	for key, oldVal := range expected.Groups {
		newVal, ok := actual.Groups[key]
		if !assert.True(t, ok, "Group with ID '%s' should exist in new account", key) {
			continue
		}
		sort.Strings(oldVal.Peers)
		sort.Strings(newVal.Peers)
		assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
	}

	assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
	for key, oldVal := range expected.Routes {
		newVal, ok := actual.Routes[key]
		if !assert.True(t, ok, "Route with ID '%s' should exist in new account", key) {
			continue
		}
		assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key)
	}

	assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements")
	for key, oldVal := range expected.NameServerGroups {
		newVal, ok := actual.NameServerGroups[key]
		if !assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) {
			continue
		}
		assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key)
	}

	// Slices are compared index by index, which is only safe when lengths match.
	if assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") {
		sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID })
		sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID })
		for i := range expected.Policies {
			sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID })
			sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID })
			assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID)
		}
	}

	if assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") {
		sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID })
		sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID })
		for i := range expected.PostureChecks {
			assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID)
		}
	}

	if assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") {
		sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID })
		sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID })
		for i := range expected.Networks {
			assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID)
		}
	}

	if assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") {
		sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID })
		sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID })
		for i := range expected.NetworkRouters {
			assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID)
		}
	}

	if assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") {
		sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID })
		sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID })
		for i := range expected.NetworkResources {
			assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID)
		}
	}
}
// normalizeNilSlices replaces nil slices with empty, non-nil ones. It covers
// acc itself and the objects it holds: setup keys, peers, users, groups,
// routes, nameserver groups, policies and their rules, posture checks and
// network routers. This lets accounts that differ only in nil-vs-empty slices
// compare equal with reflection-based assertions. Non-nil slices are left
// as they are, and a nil acc is ignored.
func normalizeNilSlices(acc *types.Account) {
	if acc == nil {
		return
	}

	// orEmpty returns s unchanged unless it is nil, in which case it returns a
	// new empty slice.
	orEmpty := func(s []string) []string {
		if s == nil {
			return []string{}
		}
		return s
	}

	// Top-level account slices.
	if acc.Policies == nil {
		acc.Policies = []*types.Policy{}
	}
	if acc.PostureChecks == nil {
		acc.PostureChecks = []*posture.Checks{}
	}
	if acc.Networks == nil {
		acc.Networks = []*networkTypes.Network{}
	}
	if acc.NetworkRouters == nil {
		acc.NetworkRouters = []*routerTypes.NetworkRouter{}
	}
	if acc.NetworkResources == nil {
		acc.NetworkResources = []*resourceTypes.NetworkResource{}
	}
	acc.DNSSettings.DisabledManagementGroups = orEmpty(acc.DNSSettings.DisabledManagementGroups)

	// Map-held objects.
	for _, k := range acc.SetupKeys {
		k.AutoGroups = orEmpty(k.AutoGroups)
	}
	for _, p := range acc.Peers {
		p.ExtraDNSLabels = orEmpty(p.ExtraDNSLabels)
	}
	for _, u := range acc.Users {
		u.AutoGroups = orEmpty(u.AutoGroups)
	}
	for _, g := range acc.Groups {
		g.Peers = orEmpty(g.Peers)
		if g.Resources == nil {
			g.Resources = []types.Resource{}
		}
		if g.GroupPeers == nil {
			g.GroupPeers = []types.GroupPeer{}
		}
	}
	for _, r := range acc.Routes {
		if r.Domains == nil {
			r.Domains = domain.List{}
		}
		r.PeerGroups = orEmpty(r.PeerGroups)
		r.Groups = orEmpty(r.Groups)
		r.AccessControlGroups = orEmpty(r.AccessControlGroups)
	}
	for _, ns := range acc.NameServerGroups {
		if ns.NameServers == nil {
			ns.NameServers = []nbdns.NameServer{}
		}
		ns.Groups = orEmpty(ns.Groups)
		ns.Domains = orEmpty(ns.Domains)
	}

	// Policies and their rules.
	for _, pol := range acc.Policies {
		pol.SourcePostureChecks = orEmpty(pol.SourcePostureChecks)
		if pol.Rules == nil {
			pol.Rules = []*types.PolicyRule{}
		}
		for _, rule := range pol.Rules {
			rule.Destinations = orEmpty(rule.Destinations)
			rule.Sources = orEmpty(rule.Sources)
			rule.Ports = orEmpty(rule.Ports)
			if rule.PortRanges == nil {
				rule.PortRanges = []types.RulePortRange{}
			}
		}
	}

	// Posture checks: only the sub-checks that are present are touched.
	for _, pc := range acc.PostureChecks {
		if geo := pc.Checks.GeoLocationCheck; geo != nil && geo.Locations == nil {
			geo.Locations = []posture.Location{}
		}
		if nr := pc.Checks.PeerNetworkRangeCheck; nr != nil && nr.Ranges == nil {
			nr.Ranges = []netip.Prefix{}
		}
		if proc := pc.Checks.ProcessCheck; proc != nil && proc.Processes == nil {
			proc.Processes = []posture.Process{}
		}
	}

	for _, rt := range acc.NetworkRouters {
		rt.PeerGroups = orEmpty(rt.PeerGroups)
	}
}
// TestGetAccountEquals verifies that GetAccount returns the same account as the
// GORM-based reference loader GetAccountSlow.
//
// Some values depend on the database's row order: group peer lists, policies
// and their rules, posture checks, networks, network routers and network
// resources. These are sorted by ID on both sides before comparing, because
// SQL without ORDER BY returns rows in no guaranteed order.
func TestGetAccountEquals(t *testing.T) {
	store, accountID := setupBenchmarkDB(t)
	ctx := context.Background()

	expected, err := store.GetAccountSlow(ctx, accountID)
	require.NoError(t, err)
	actual, err := store.GetAccount(ctx, accountID)
	require.NoError(t, err)

	// normalizeOrder sorts row-order-dependent values in place.
	normalizeOrder := func(acc *types.Account) {
		sort.Slice(acc.Policies, func(i, j int) bool { return acc.Policies[i].ID < acc.Policies[j].ID })
		for _, p := range acc.Policies {
			sort.Slice(p.Rules, func(i, j int) bool { return p.Rules[i].ID < p.Rules[j].ID })
		}
		sort.Slice(acc.PostureChecks, func(i, j int) bool { return acc.PostureChecks[i].ID < acc.PostureChecks[j].ID })
		sort.Slice(acc.Networks, func(i, j int) bool { return acc.Networks[i].ID < acc.Networks[j].ID })
		sort.Slice(acc.NetworkRouters, func(i, j int) bool { return acc.NetworkRouters[i].ID < acc.NetworkRouters[j].ID })
		sort.Slice(acc.NetworkResources, func(i, j int) bool { return acc.NetworkResources[i].ID < acc.NetworkResources[j].ID })
		for _, g := range acc.Groups {
			sort.Strings(g.Peers)
		}
	}
	normalizeOrder(expected)
	normalizeOrder(actual)

	require.Equal(t, expected.DNSSettings, actual.DNSSettings)
	require.Equal(t, expected.Domain, actual.Domain)
	require.Equal(t, expected.DomainCategory, actual.DomainCategory)
	require.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount)
	require.Equal(t, expected.Id, actual.Id)
	require.Equal(t, expected.CreatedBy, actual.CreatedBy)
	require.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second)

	require.Equal(t, len(expected.SetupKeys), len(actual.SetupKeys))
	for k, v := range expected.SetupKeys {
		v2, ok := actual.SetupKeys[k]
		require.True(t, ok)
		require.Equal(t, v, v2)
	}

	require.Equal(t, len(expected.Peers), len(actual.Peers))
	for k, v := range expected.Peers {
		v2, ok := actual.Peers[k]
		require.True(t, ok)
		require.Equal(t, v, v2)
	}

	require.Equal(t, len(expected.Users), len(actual.Users))
	for k, v := range expected.Users {
		v2, ok := actual.Users[k]
		require.True(t, ok)
		// Check PATs first so a PAT mismatch is reported precisely rather than
		// as a whole-user diff.
		require.Equal(t, len(v.PATs), len(v2.PATs))
		for k3, v3 := range v.PATs {
			v4, ok := v2.PATs[k3]
			require.True(t, ok)
			require.Equal(t, v3, v4)
		}
		require.Equal(t, v, v2)
	}

	require.Equal(t, len(expected.Groups), len(actual.Groups))
	for k, v := range expected.Groups {
		v2, ok := actual.Groups[k]
		require.True(t, ok)
		require.Equal(t, v, v2)
	}

	require.Equal(t, len(expected.Routes), len(actual.Routes))
	for k, v := range expected.Routes {
		v2, ok := actual.Routes[k]
		require.True(t, ok)
		require.Equal(t, v, v2)
	}

	require.Equal(t, len(expected.NameServerGroups), len(actual.NameServerGroups))
	for k, v := range expected.NameServerGroups {
		v2, ok := actual.NameServerGroups[k]
		require.True(t, ok)
		require.Equal(t, v, v2)
	}

	require.Equal(t, expected.Policies, actual.Policies)
	require.Equal(t, expected.PostureChecks, actual.PostureChecks)
	require.Equal(t, expected.Network, actual.Network)
	require.Equal(t, expected.Networks, actual.Networks)
	require.Equal(t, expected.NetworkRouters, actual.NetworkRouters)
	require.Equal(t, expected.NetworkResources, actual.NetworkResources)
	require.Equal(t, expected.Onboarding, actual.Onboarding)
}
func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) {
var account types.Account
account.Network = &types.Network{}
const accountQuery = `
SELECT
id, created_by, created_at, domain, domain_category, is_domain_primary_account,
-- Embedded Network
network_identifier, network_net, network_dns, network_serial,
-- Embedded DNSSettings
dns_settings_disabled_management_groups,
-- Embedded Settings
settings_peer_login_expiration_enabled, settings_peer_login_expiration,
settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration,
settings_regular_users_view_blocked, settings_groups_propagation_enabled,
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
settings_lazy_connection_enabled,
-- Embedded ExtraSettings
settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
settings_extra_integrated_validator, settings_extra_integrated_validator_groups
FROM accounts WHERE id = $1`
var networkNet, dnsSettingsDisabledGroups []byte
var (
sPeerLoginExpirationEnabled sql.NullBool
sPeerLoginExpiration sql.NullInt64
sPeerInactivityExpirationEnabled sql.NullBool
sPeerInactivityExpiration sql.NullInt64
sRegularUsersViewBlocked sql.NullBool
sGroupsPropagationEnabled sql.NullBool
sJWTGroupsEnabled sql.NullBool
sJWTGroupsClaimName sql.NullString
sJWTAllowGroups []byte
sRoutingPeerDNSResolutionEnabled sql.NullBool
sDNSDomain sql.NullString
sNetworkRange []byte
sLazyConnectionEnabled sql.NullBool
sExtraPeerApprovalEnabled sql.NullBool
sExtraUserApprovalRequired sql.NullBool
sExtraIntegratedValidator sql.NullString
sExtraIntegratedValidatorGroups []byte
)
err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan(
&account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount,
&account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial,
&dnsSettingsDisabledGroups,
&sPeerLoginExpirationEnabled, &sPeerLoginExpiration,
&sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration,
&sRegularUsersViewBlocked, &sGroupsPropagationEnabled,
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
&sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
&sLazyConnectionEnabled,
&sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, errors.New("account not found")
}
return nil, err
}
_ = json.Unmarshal(networkNet, &account.Network.Net)
_ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups)
account.Settings = &types.Settings{Extra: &types.ExtraSettings{}}
if sPeerLoginExpirationEnabled.Valid {
account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool
}
if sPeerLoginExpiration.Valid {
account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64)
}
if sPeerInactivityExpirationEnabled.Valid {
account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool
}
if sPeerInactivityExpiration.Valid {
account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64)
}
if sRegularUsersViewBlocked.Valid {
account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool
}
if sGroupsPropagationEnabled.Valid {
account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool
}
if sJWTGroupsEnabled.Valid {
account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool
}
if sJWTGroupsClaimName.Valid {
account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String
}
if sRoutingPeerDNSResolutionEnabled.Valid {
account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool
}
if sDNSDomain.Valid {
account.Settings.DNSDomain = sDNSDomain.String
}
if sLazyConnectionEnabled.Valid {
account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool
}
if sJWTAllowGroups != nil {
_ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups)
}
if sNetworkRange != nil {
_ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange)
}
if sExtraPeerApprovalEnabled.Valid {
account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool
}
if sExtraUserApprovalRequired.Valid {
account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool
}
if sExtraIntegratedValidator.Valid {
account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String
}
if sExtraIntegratedValidatorGroups != nil {
_ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups)
}
var wg sync.WaitGroup
errChan := make(chan error, 12)
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) {
var sk types.SetupKey
var autoGroups []byte
var expiresAt, updatedAt, lastUsed sql.NullTime
var revoked, ephemeral, allowExtraDNSLabels sql.NullBool
var usedTimes, usageLimit sql.NullInt64
err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels)
if err == nil {
if expiresAt.Valid {
sk.ExpiresAt = &expiresAt.Time
}
if updatedAt.Valid {
sk.UpdatedAt = updatedAt.Time
if sk.UpdatedAt.IsZero() {
sk.UpdatedAt = sk.CreatedAt
}
}
if lastUsed.Valid {
sk.LastUsed = &lastUsed.Time
}
if revoked.Valid {
sk.Revoked = revoked.Bool
}
if usedTimes.Valid {
sk.UsedTimes = int(usedTimes.Int64)
}
if usageLimit.Valid {
sk.UsageLimit = int(usageLimit.Int64)
}
if ephemeral.Valid {
sk.Ephemeral = ephemeral.Bool
}
if allowExtraDNSLabels.Valid {
sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
}
if autoGroups != nil {
_ = json.Unmarshal(autoGroups, &sk.AutoGroups)
} else {
sk.AutoGroups = []string{}
}
}
return sk, err
})
if err != nil {
errChan <- err
return
}
account.SetupKeysG = keys
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) {
var p nbpeer.Peer
p.Status = &nbpeer.PeerStatus{}
var ip, extraDNS, netAddr, env, flags, files, connIP []byte
err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID)
if err == nil {
if ip != nil {
_ = json.Unmarshal(ip, &p.IP)
}
if extraDNS != nil {
_ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels)
}
if netAddr != nil {
_ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses)
}
if env != nil {
_ = json.Unmarshal(env, &p.Meta.Environment)
}
if flags != nil {
_ = json.Unmarshal(flags, &p.Meta.Flags)
}
if files != nil {
_ = json.Unmarshal(files, &p.Meta.Files)
}
if connIP != nil {
_ = json.Unmarshal(connIP, &p.Location.ConnectionIP)
}
}
return p, err
})
if err != nil {
errChan <- err
return
}
account.PeersG = peers
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
var u types.User
var autoGroups []byte
var lastLogin sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
if err == nil {
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
if autoGroups != nil {
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
} else {
u.AutoGroups = []string{}
}
}
return u, err
})
if err != nil {
errChan <- err
return
}
account.UsersG = users
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) {
var g types.Group
var resources []byte
var refID sql.NullInt64
var refType sql.NullString
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
if err == nil {
if refID.Valid {
g.IntegrationReference.ID = int(refID.Int64)
}
if refType.Valid {
g.IntegrationReference.IntegrationType = refType.String
}
if resources != nil {
_ = json.Unmarshal(resources, &g.Resources)
} else {
g.Resources = []types.Resource{}
}
g.GroupPeers = []types.GroupPeer{}
g.Peers = []string{}
}
return &g, err
})
if err != nil {
errChan <- err
return
}
account.GroupsG = groups
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) {
var p types.Policy
var checks []byte
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks)
if err == nil && checks != nil {
_ = json.Unmarshal(checks, &p.SourcePostureChecks)
}
return &p, err
})
if err != nil {
errChan <- err
return
}
account.Policies = policies
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) {
var r route.Route
var network, domains, peerGroups, groups, accessGroups []byte
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply)
if err == nil {
if network != nil {
_ = json.Unmarshal(network, &r.Network)
}
if domains != nil {
_ = json.Unmarshal(domains, &r.Domains)
}
if peerGroups != nil {
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
}
if groups != nil {
_ = json.Unmarshal(groups, &r.Groups)
}
if accessGroups != nil {
_ = json.Unmarshal(accessGroups, &r.AccessControlGroups)
}
}
return r, err
})
if err != nil {
errChan <- err
return
}
account.RoutesG = routes
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) {
var n nbdns.NameServerGroup
var ns, groups, domains []byte
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled)
if err == nil {
if ns != nil {
_ = json.Unmarshal(ns, &n.NameServers)
} else {
n.NameServers = []nbdns.NameServer{}
}
if groups != nil {
_ = json.Unmarshal(groups, &n.Groups)
} else {
n.Groups = []string{}
}
if domains != nil {
_ = json.Unmarshal(domains, &n.Domains)
} else {
n.Domains = []string{}
}
}
return n, err
})
if err != nil {
errChan <- err
return
}
account.NameServerGroupsG = nsgs
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
var c posture.Checks
var checksDef []byte
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
if err == nil && checksDef != nil {
_ = json.Unmarshal(checksDef, &c.Checks)
}
return &c, err
})
if err != nil {
errChan <- err
return
}
account.PostureChecks = checks
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network])
if err != nil {
errChan <- err
return
}
account.Networks = make([]*networkTypes.Network, len(networks))
for i := range networks {
account.Networks[i] = &networks[i]
}
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) {
var r routerTypes.NetworkRouter
var peerGroups []byte
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled)
if err == nil && peerGroups != nil {
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
}
return &r, err
})
if err != nil {
errChan <- err
return
}
account.NetworkRouters = routers
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
errChan <- err
return
}
resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) {
var r resourceTypes.NetworkResource
var prefix []byte
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled)
if err == nil && prefix != nil {
_ = json.Unmarshal(prefix, &r.Prefix)
}
return &r, err
})
if err != nil {
errChan <- err
return
}
account.NetworkResources = resources
}()
wg.Add(1)
go func() {
defer wg.Done()
const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
err := s.pool.QueryRow(ctx, query, accountID).Scan(
&account.Onboarding.AccountID,
&account.Onboarding.OnboardingFlowPending,
&account.Onboarding.SignupFormPending,
&account.Onboarding.CreatedAt,
&account.Onboarding.UpdatedAt,
)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
var userIDs []string
for _, u := range account.UsersG {
userIDs = append(userIDs, u.Id)
}
var policyIDs []string
for _, p := range account.Policies {
policyIDs = append(policyIDs, p.ID)
}
var groupIDs []string
for _, g := range account.GroupsG {
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
var pats []types.PersonalAccessToken
go func() {
defer wg.Done()
if len(userIDs) == 0 {
return
}
const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, userIDs)
if err != nil {
errChan <- err
return
}
pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken])
if err != nil {
errChan <- err
}
}()
var rules []*types.PolicyRule
go func() {
defer wg.Done()
if len(policyIDs) == 0 {
return
}
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, policyIDs)
if err != nil {
errChan <- err
return
}
rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
var r types.PolicyRule
var dest, destRes, sources, sourceRes, ports, portRanges []byte
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges)
if err == nil {
if dest != nil {
_ = json.Unmarshal(dest, &r.Destinations)
}
if destRes != nil {
_ = json.Unmarshal(destRes, &r.DestinationResource)
}
if sources != nil {
_ = json.Unmarshal(sources, &r.Sources)
}
if sourceRes != nil {
_ = json.Unmarshal(sourceRes, &r.SourceResource)
}
if ports != nil {
_ = json.Unmarshal(ports, &r.Ports)
}
if portRanges != nil {
_ = json.Unmarshal(portRanges, &r.PortRanges)
}
}
return &r, err
})
if err != nil {
errChan <- err
}
}()
var groupPeers []types.GroupPeer
go func() {
defer wg.Done()
if len(groupIDs) == 0 {
return
}
const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, groupIDs)
if err != nil {
errChan <- err
return
}
groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
patsByUserID := make(map[string][]*types.PersonalAccessToken)
for i := range pats {
pat := &pats[i]
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
pat.UserID = ""
}
rulesByPolicyID := make(map[string][]*types.PolicyRule)
for _, rule := range rules {
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
}
peersByGroupID := make(map[string][]string)
for _, gp := range groupPeers {
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
account.SetupKeys[key.Key] = key
}
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for i := range account.PeersG {
peer := &account.PeersG[i]
account.Peers[peer.ID] = peer
}
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user.PATs = make(map[string]*types.PersonalAccessToken)
if userPats, ok := patsByUserID[user.Id]; ok {
for j := range userPats {
pat := userPats[j]
user.PATs[pat.ID] = pat
}
}
account.Users[user.Id] = user
}
for i := range account.Policies {
policy := account.Policies[i]
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
policy.Rules = policyRules
}
}
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for i := range account.GroupsG {
group := account.GroupsG[i]
if peerIDs, ok := peersByGroupID[group.ID]; ok {
group.Peers = peerIDs
}
account.Groups[group.ID] = group
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for i := range account.RoutesG {
route := &account.RoutesG[i]
account.Routes[route.ID] = route
}
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for i := range account.NameServerGroupsG {
nsg := &account.NameServerGroupsG[i]
nsg.AccountID = ""
account.NameServerGroups[nsg.ID] = nsg
}
account.SetupKeysG = nil
account.PeersG = nil
account.UsersG = nil
account.GroupsG = nil
account.RoutesG = nil
account.NameServerGroupsG = nil
return &account, nil
}