mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
1494 lines
48 KiB
Go
1494 lines
48 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sort"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
"github.com/netbirdio/netbird/management/server/posture"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
"github.com/netbirdio/netbird/route"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
"github.com/netbirdio/netbird/shared/management/status"
|
|
)
|
|
|
|
func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) {
|
|
start := time.Now()
|
|
defer func() {
|
|
elapsed := time.Since(start)
|
|
if elapsed > 1*time.Second {
|
|
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
|
|
}
|
|
}()
|
|
|
|
var account types.Account
|
|
result := s.db.Model(&account).
|
|
Omit("GroupsG").
|
|
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
|
Preload(clause.Associations).
|
|
Take(&account, idQueryCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
|
for i, policy := range account.Policies {
|
|
var rules []*types.PolicyRule
|
|
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
|
if err != nil {
|
|
return nil, status.Errorf(status.NotFound, "rule not found")
|
|
}
|
|
account.Policies[i].Rules = rules
|
|
}
|
|
|
|
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
|
for _, key := range account.SetupKeysG {
|
|
account.SetupKeys[key.Key] = key.Copy()
|
|
}
|
|
account.SetupKeysG = nil
|
|
|
|
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
|
for _, peer := range account.PeersG {
|
|
account.Peers[peer.ID] = peer.Copy()
|
|
}
|
|
account.PeersG = nil
|
|
|
|
account.Users = make(map[string]*types.User, len(account.UsersG))
|
|
for _, user := range account.UsersG {
|
|
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
|
|
for _, pat := range user.PATsG {
|
|
user.PATs[pat.ID] = pat.Copy()
|
|
}
|
|
account.Users[user.Id] = user.Copy()
|
|
}
|
|
account.UsersG = nil
|
|
|
|
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
|
for _, group := range account.GroupsG {
|
|
account.Groups[group.ID] = group.Copy()
|
|
}
|
|
account.GroupsG = nil
|
|
|
|
var groupPeers []types.GroupPeer
|
|
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
|
|
Find(&groupPeers)
|
|
for _, groupPeer := range groupPeers {
|
|
if group, ok := account.Groups[groupPeer.GroupID]; ok {
|
|
group.Peers = append(group.Peers, groupPeer.PeerID)
|
|
} else {
|
|
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
|
|
}
|
|
}
|
|
|
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
|
for _, route := range account.RoutesG {
|
|
account.Routes[route.ID] = route.Copy()
|
|
}
|
|
account.RoutesG = nil
|
|
|
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
|
for _, ns := range account.NameServerGroupsG {
|
|
account.NameServerGroups[ns.ID] = ns.Copy()
|
|
}
|
|
account.NameServerGroupsG = nil
|
|
|
|
return &account, nil
|
|
}
|
|
|
|
func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
|
config, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse database config: %w", err)
|
|
}
|
|
|
|
config.MaxConns = 10
|
|
config.MinConns = 2
|
|
config.MaxConnLifetime = time.Hour
|
|
config.HealthCheckPeriod = time.Minute
|
|
|
|
pool, err := pgxpool.NewWithConfig(ctx, config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to create connection pool: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("unable to ping database: %w", err)
|
|
}
|
|
|
|
fmt.Println("Successfully connected to the database!")
|
|
return pool, nil
|
|
}
|
|
|
|
func setupBenchmarkDB(b testing.TB) (*SqlStore, string) {
|
|
dsn := "host=localhost user=postgres password=mysecretpassword dbname=testdb port=5432 sslmode=disable TimeZone=Europe/Berlin"
|
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
if err != nil {
|
|
b.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
pool, err := connectDB(context.Background(), dsn)
|
|
if err != nil {
|
|
b.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
models := []interface{}{
|
|
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
|
|
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
|
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
|
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
|
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
|
&types.AccountOnboarding{},
|
|
}
|
|
|
|
for i := len(models) - 1; i >= 0; i-- {
|
|
db.Migrator().DropTable(models[i])
|
|
}
|
|
|
|
err = db.AutoMigrate(models...)
|
|
if err != nil {
|
|
b.Fatalf("failed to migrate database: %v", err)
|
|
}
|
|
|
|
store := &SqlStore{
|
|
db: db,
|
|
pool: pool,
|
|
}
|
|
|
|
const (
|
|
accountID = "benchmark-account-id"
|
|
numUsers = 20
|
|
numPatsPerUser = 3
|
|
numSetupKeys = 25
|
|
numPeers = 200
|
|
numGroups = 30
|
|
numPolicies = 50
|
|
numRulesPerPolicy = 10
|
|
numRoutes = 40
|
|
numNSGroups = 10
|
|
numPostureChecks = 15
|
|
numNetworks = 5
|
|
numNetworkRouters = 5
|
|
numNetworkResources = 10
|
|
)
|
|
|
|
_, ipNet, _ := net.ParseCIDR("100.64.0.0/10")
|
|
acc := types.Account{
|
|
Id: accountID,
|
|
CreatedBy: "benchmark-user",
|
|
CreatedAt: time.Now(),
|
|
Domain: "benchmark.com",
|
|
IsDomainPrimaryAccount: true,
|
|
Network: &types.Network{
|
|
Identifier: "benchmark-net",
|
|
Net: *ipNet,
|
|
Serial: 1,
|
|
},
|
|
DNSSettings: types.DNSSettings{
|
|
DisabledManagementGroups: []string{"group-disabled-1"},
|
|
},
|
|
Settings: &types.Settings{},
|
|
}
|
|
if err := db.Create(&acc).Error; err != nil {
|
|
b.Fatalf("create account: %v", err)
|
|
}
|
|
|
|
var setupKeys []types.SetupKey
|
|
for i := 0; i < numSetupKeys; i++ {
|
|
setupKeys = append(setupKeys, types.SetupKey{
|
|
Id: fmt.Sprintf("keyid-%d", i),
|
|
AccountID: accountID,
|
|
Key: fmt.Sprintf("key-%d", i),
|
|
Name: fmt.Sprintf("Benchmark Key %d", i),
|
|
ExpiresAt: &time.Time{},
|
|
})
|
|
}
|
|
if err := db.Create(&setupKeys).Error; err != nil {
|
|
b.Fatalf("create setup keys: %v", err)
|
|
}
|
|
|
|
var peers []nbpeer.Peer
|
|
for i := 0; i < numPeers; i++ {
|
|
peers = append(peers, nbpeer.Peer{
|
|
ID: fmt.Sprintf("peer-%d", i),
|
|
AccountID: accountID,
|
|
Key: fmt.Sprintf("peerkey-%d", i),
|
|
IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)),
|
|
Name: fmt.Sprintf("peer-name-%d", i),
|
|
Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()},
|
|
})
|
|
}
|
|
if err := db.Create(&peers).Error; err != nil {
|
|
b.Fatalf("create peers: %v", err)
|
|
}
|
|
|
|
for i := 0; i < numUsers; i++ {
|
|
userID := fmt.Sprintf("user-%d", i)
|
|
user := types.User{Id: userID, AccountID: accountID}
|
|
if err := db.Create(&user).Error; err != nil {
|
|
b.Fatalf("create user %s: %v", userID, err)
|
|
}
|
|
|
|
var pats []types.PersonalAccessToken
|
|
for j := 0; j < numPatsPerUser; j++ {
|
|
pats = append(pats, types.PersonalAccessToken{
|
|
ID: fmt.Sprintf("pat-%d-%d", i, j),
|
|
UserID: userID,
|
|
Name: fmt.Sprintf("PAT %d for User %d", j, i),
|
|
})
|
|
}
|
|
if err := db.Create(&pats).Error; err != nil {
|
|
b.Fatalf("create pats for user %s: %v", userID, err)
|
|
}
|
|
}
|
|
|
|
var groups []*types.Group
|
|
for i := 0; i < numGroups; i++ {
|
|
groups = append(groups, &types.Group{
|
|
ID: fmt.Sprintf("group-%d", i),
|
|
AccountID: accountID,
|
|
Name: fmt.Sprintf("Group %d", i),
|
|
})
|
|
}
|
|
if err := db.Create(&groups).Error; err != nil {
|
|
b.Fatalf("create groups: %v", err)
|
|
}
|
|
|
|
for i := 0; i < numPolicies; i++ {
|
|
policyID := fmt.Sprintf("policy-%d", i)
|
|
policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true}
|
|
if err := db.Create(&policy).Error; err != nil {
|
|
b.Fatalf("create policy %s: %v", policyID, err)
|
|
}
|
|
|
|
var rules []*types.PolicyRule
|
|
for j := 0; j < numRulesPerPolicy; j++ {
|
|
rules = append(rules, &types.PolicyRule{
|
|
ID: fmt.Sprintf("rule-%d-%d", i, j),
|
|
PolicyID: policyID,
|
|
Name: fmt.Sprintf("Rule %d for Policy %d", j, i),
|
|
Enabled: true,
|
|
Protocol: "all",
|
|
})
|
|
}
|
|
if err := db.Create(&rules).Error; err != nil {
|
|
b.Fatalf("create rules for policy %s: %v", policyID, err)
|
|
}
|
|
}
|
|
|
|
var routes []route.Route
|
|
for i := 0; i < numRoutes; i++ {
|
|
routes = append(routes, route.Route{
|
|
ID: route.ID(fmt.Sprintf("route-%d", i)),
|
|
AccountID: accountID,
|
|
Description: fmt.Sprintf("Route %d", i),
|
|
Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)),
|
|
Enabled: true,
|
|
})
|
|
}
|
|
if err := db.Create(&routes).Error; err != nil {
|
|
b.Fatalf("create routes: %v", err)
|
|
}
|
|
|
|
var nsGroups []nbdns.NameServerGroup
|
|
for i := 0; i < numNSGroups; i++ {
|
|
nsGroups = append(nsGroups, nbdns.NameServerGroup{
|
|
ID: fmt.Sprintf("nsg-%d", i),
|
|
AccountID: accountID,
|
|
Name: fmt.Sprintf("NS Group %d", i),
|
|
Description: "Benchmark NS Group",
|
|
Enabled: true,
|
|
})
|
|
}
|
|
if err := db.Create(&nsGroups).Error; err != nil {
|
|
b.Fatalf("create nsgroups: %v", err)
|
|
}
|
|
|
|
var postureChecks []*posture.Checks
|
|
for i := 0; i < numPostureChecks; i++ {
|
|
postureChecks = append(postureChecks, &posture.Checks{
|
|
ID: fmt.Sprintf("pc-%d", i),
|
|
AccountID: accountID,
|
|
Name: fmt.Sprintf("Posture Check %d", i),
|
|
})
|
|
}
|
|
if err := db.Create(&postureChecks).Error; err != nil {
|
|
b.Fatalf("create posture checks: %v", err)
|
|
}
|
|
|
|
var networks []*networkTypes.Network
|
|
for i := 0; i < numNetworks; i++ {
|
|
networks = append(networks, &networkTypes.Network{
|
|
ID: fmt.Sprintf("nettype-%d", i),
|
|
AccountID: accountID,
|
|
Name: fmt.Sprintf("Network Type %d", i),
|
|
})
|
|
}
|
|
if err := db.Create(&networks).Error; err != nil {
|
|
b.Fatalf("create networks: %v", err)
|
|
}
|
|
|
|
var networkRouters []*routerTypes.NetworkRouter
|
|
for i := 0; i < numNetworkRouters; i++ {
|
|
networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
|
|
ID: fmt.Sprintf("router-%d", i),
|
|
AccountID: accountID,
|
|
NetworkID: networks[i%numNetworks].ID,
|
|
Peer: peers[i%numPeers].ID,
|
|
})
|
|
}
|
|
if err := db.Create(&networkRouters).Error; err != nil {
|
|
b.Fatalf("create network routers: %v", err)
|
|
}
|
|
|
|
var networkResources []*resourceTypes.NetworkResource
|
|
for i := 0; i < numNetworkResources; i++ {
|
|
networkResources = append(networkResources, &resourceTypes.NetworkResource{
|
|
ID: fmt.Sprintf("resource-%d", i),
|
|
AccountID: accountID,
|
|
NetworkID: networks[i%numNetworks].ID,
|
|
Name: fmt.Sprintf("Resource %d", i),
|
|
})
|
|
}
|
|
if err := db.Create(&networkResources).Error; err != nil {
|
|
b.Fatalf("create network resources: %v", err)
|
|
}
|
|
|
|
onboarding := types.AccountOnboarding{
|
|
AccountID: accountID,
|
|
OnboardingFlowPending: true,
|
|
}
|
|
if err := db.Create(&onboarding).Error; err != nil {
|
|
b.Fatalf("create onboarding: %v", err)
|
|
}
|
|
|
|
return store, accountID
|
|
}
|
|
|
|
func BenchmarkGetAccount(b *testing.B) {
|
|
store, accountID := setupBenchmarkDB(b)
|
|
ctx := context.Background()
|
|
b.ResetTimer()
|
|
b.ReportAllocs()
|
|
b.Run("old", func(b *testing.B) {
|
|
for range b.N {
|
|
_, err := store.GetAccountSlow(ctx, accountID)
|
|
if err != nil {
|
|
b.Fatalf("GetAccountSlow failed: %v", err)
|
|
}
|
|
}
|
|
})
|
|
b.Run("new", func(b *testing.B) {
|
|
for range b.N {
|
|
_, err := store.GetAccount(ctx, accountID)
|
|
if err != nil {
|
|
b.Fatalf("GetAccountFast failed: %v", err)
|
|
}
|
|
}
|
|
})
|
|
b.Run("raw", func(b *testing.B) {
|
|
for range b.N {
|
|
_, err := store.GetAccountPureSQL(ctx, accountID)
|
|
if err != nil {
|
|
b.Fatalf("GetAccountPureSQL failed: %v", err)
|
|
}
|
|
}
|
|
})
|
|
store.pool.Close()
|
|
}
|
|
|
|
func TestAccountEquivalence(t *testing.T) {
|
|
store, accountID := setupBenchmarkDB(t)
|
|
ctx := context.Background()
|
|
|
|
type getAccountFunc func(context.Context, string) (*types.Account, error)
|
|
|
|
tests := []struct {
|
|
name string
|
|
expectedF getAccountFunc
|
|
actualF getAccountFunc
|
|
}{
|
|
// {"old vs new", store.GetAccountSlow, store.GetAccount},
|
|
{"old vs raw", store.GetAccountSlow, store.GetAccountPureSQL},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
expected, errOld := tt.expectedF(ctx, accountID)
|
|
assert.NoError(t, errOld, "expected function should not return an error")
|
|
assert.NotNil(t, expected, "expected should not be nil")
|
|
|
|
actual, errNew := tt.actualF(ctx, accountID)
|
|
assert.NoError(t, errNew, "actual function should not return an error")
|
|
assert.NotNil(t, actual, "actual should not be nil")
|
|
testAccountEquivalence(t, expected, actual)
|
|
})
|
|
}
|
|
|
|
expected, errOld := store.GetAccountSlow(ctx, accountID)
|
|
assert.NoError(t, errOld, "GetAccountSlow should not return an error")
|
|
assert.NotNil(t, expected, "expected should not be nil")
|
|
|
|
actual, errNew := store.GetAccount(ctx, accountID)
|
|
assert.NoError(t, errNew, "GetAccount (new) should not return an error")
|
|
assert.NotNil(t, actual, "actual should not be nil")
|
|
}
|
|
|
|
func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
|
|
// normalizeNilSlices(expected)
|
|
// normalizeNilSlices(actual)
|
|
|
|
assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal")
|
|
assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal")
|
|
assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second")
|
|
assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal")
|
|
assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal")
|
|
assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal")
|
|
assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal")
|
|
assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal")
|
|
assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal")
|
|
|
|
assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements")
|
|
for key, oldVal := range expected.SetupKeys {
|
|
newVal, ok := actual.SetupKeys[key]
|
|
assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key)
|
|
assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key)
|
|
}
|
|
|
|
assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements")
|
|
for key, oldVal := range expected.Peers {
|
|
newVal, ok := actual.Peers[key]
|
|
assert.True(t, ok, "Peer with ID '%s' should exist in new account", key)
|
|
assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key)
|
|
}
|
|
|
|
assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements")
|
|
for key, oldUser := range expected.Users {
|
|
newUser, ok := actual.Users[key]
|
|
assert.True(t, ok, "User with ID '%s' should exist in new account", key)
|
|
|
|
assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key)
|
|
for patKey, oldPAT := range oldUser.PATs {
|
|
newPAT, patOk := newUser.PATs[patKey]
|
|
assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key)
|
|
assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key)
|
|
}
|
|
|
|
oldUser.PATs = nil
|
|
newUser.PATs = nil
|
|
assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key)
|
|
}
|
|
|
|
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
|
|
for key, oldVal := range expected.Groups {
|
|
newVal, ok := actual.Groups[key]
|
|
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
|
|
sort.Strings(oldVal.Peers)
|
|
sort.Strings(newVal.Peers)
|
|
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
|
|
}
|
|
|
|
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
|
|
for key, oldVal := range expected.Routes {
|
|
newVal, ok := actual.Routes[key]
|
|
assert.True(t, ok, "Route with ID '%s' should exist in new account", key)
|
|
assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key)
|
|
}
|
|
|
|
assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements")
|
|
for key, oldVal := range expected.NameServerGroups {
|
|
newVal, ok := actual.NameServerGroups[key]
|
|
assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key)
|
|
assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key)
|
|
}
|
|
|
|
assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements")
|
|
sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID })
|
|
sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID })
|
|
for i := range expected.Policies {
|
|
sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID })
|
|
sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID })
|
|
assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID)
|
|
}
|
|
|
|
assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements")
|
|
sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID })
|
|
sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID })
|
|
for i := range expected.PostureChecks {
|
|
assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID)
|
|
}
|
|
|
|
assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements")
|
|
sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID })
|
|
sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID })
|
|
for i := range expected.Networks {
|
|
assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID)
|
|
}
|
|
|
|
assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements")
|
|
sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID })
|
|
sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID })
|
|
for i := range expected.NetworkRouters {
|
|
assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID)
|
|
}
|
|
|
|
assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements")
|
|
sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID })
|
|
sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID })
|
|
for i := range expected.NetworkResources {
|
|
assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID)
|
|
}
|
|
}
|
|
|
|
func normalizeNilSlices(acc *types.Account) {
|
|
if acc == nil {
|
|
return
|
|
}
|
|
|
|
if acc.Policies == nil {
|
|
acc.Policies = []*types.Policy{}
|
|
}
|
|
if acc.PostureChecks == nil {
|
|
acc.PostureChecks = []*posture.Checks{}
|
|
}
|
|
if acc.Networks == nil {
|
|
acc.Networks = []*networkTypes.Network{}
|
|
}
|
|
if acc.NetworkRouters == nil {
|
|
acc.NetworkRouters = []*routerTypes.NetworkRouter{}
|
|
}
|
|
if acc.NetworkResources == nil {
|
|
acc.NetworkResources = []*resourceTypes.NetworkResource{}
|
|
}
|
|
if acc.DNSSettings.DisabledManagementGroups == nil {
|
|
acc.DNSSettings.DisabledManagementGroups = []string{}
|
|
}
|
|
|
|
for _, key := range acc.SetupKeys {
|
|
if key.AutoGroups == nil {
|
|
key.AutoGroups = []string{}
|
|
}
|
|
}
|
|
|
|
for _, peer := range acc.Peers {
|
|
if peer.ExtraDNSLabels == nil {
|
|
peer.ExtraDNSLabels = []string{}
|
|
}
|
|
}
|
|
|
|
for _, user := range acc.Users {
|
|
if user.AutoGroups == nil {
|
|
user.AutoGroups = []string{}
|
|
}
|
|
}
|
|
|
|
for _, group := range acc.Groups {
|
|
if group.Peers == nil {
|
|
group.Peers = []string{}
|
|
}
|
|
if group.Resources == nil {
|
|
group.Resources = []types.Resource{}
|
|
}
|
|
if group.GroupPeers == nil {
|
|
group.GroupPeers = []types.GroupPeer{}
|
|
}
|
|
}
|
|
|
|
for _, route := range acc.Routes {
|
|
if route.Domains == nil {
|
|
route.Domains = domain.List{}
|
|
}
|
|
if route.PeerGroups == nil {
|
|
route.PeerGroups = []string{}
|
|
}
|
|
if route.Groups == nil {
|
|
route.Groups = []string{}
|
|
}
|
|
if route.AccessControlGroups == nil {
|
|
route.AccessControlGroups = []string{}
|
|
}
|
|
}
|
|
|
|
for _, nsg := range acc.NameServerGroups {
|
|
if nsg.NameServers == nil {
|
|
nsg.NameServers = []nbdns.NameServer{}
|
|
}
|
|
if nsg.Groups == nil {
|
|
nsg.Groups = []string{}
|
|
}
|
|
if nsg.Domains == nil {
|
|
nsg.Domains = []string{}
|
|
}
|
|
}
|
|
|
|
for _, policy := range acc.Policies {
|
|
if policy.SourcePostureChecks == nil {
|
|
policy.SourcePostureChecks = []string{}
|
|
}
|
|
if policy.Rules == nil {
|
|
policy.Rules = []*types.PolicyRule{}
|
|
}
|
|
for _, rule := range policy.Rules {
|
|
if rule.Destinations == nil {
|
|
rule.Destinations = []string{}
|
|
}
|
|
if rule.Sources == nil {
|
|
rule.Sources = []string{}
|
|
}
|
|
if rule.Ports == nil {
|
|
rule.Ports = []string{}
|
|
}
|
|
if rule.PortRanges == nil {
|
|
rule.PortRanges = []types.RulePortRange{}
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, check := range acc.PostureChecks {
|
|
if check.Checks.GeoLocationCheck != nil {
|
|
if check.Checks.GeoLocationCheck.Locations == nil {
|
|
check.Checks.GeoLocationCheck.Locations = []posture.Location{}
|
|
}
|
|
}
|
|
if check.Checks.PeerNetworkRangeCheck != nil {
|
|
if check.Checks.PeerNetworkRangeCheck.Ranges == nil {
|
|
check.Checks.PeerNetworkRangeCheck.Ranges = []netip.Prefix{}
|
|
}
|
|
}
|
|
if check.Checks.ProcessCheck != nil {
|
|
if check.Checks.ProcessCheck.Processes == nil {
|
|
check.Checks.ProcessCheck.Processes = []posture.Process{}
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, router := range acc.NetworkRouters {
|
|
if router.PeerGroups == nil {
|
|
router.PeerGroups = []string{}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetAccountEquals(t *testing.T) {
|
|
store, accountID := setupBenchmarkDB(t)
|
|
ctx := context.Background()
|
|
expected, err := store.GetAccountSlow(ctx, accountID)
|
|
require.NoError(t, err)
|
|
actual, err := store.GetAccount(ctx, accountID)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, expected.DNSSettings, actual.DNSSettings)
|
|
require.Equal(t, expected.Domain, actual.Domain)
|
|
require.Equal(t, expected.DomainCategory, actual.DomainCategory)
|
|
require.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount)
|
|
require.Equal(t, expected.Id, actual.Id)
|
|
require.Equal(t, expected.CreatedBy, actual.CreatedBy)
|
|
require.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second)
|
|
|
|
require.Equal(t, len(expected.SetupKeys), len(actual.SetupKeys))
|
|
for k, v := range expected.SetupKeys {
|
|
v2, ok := actual.SetupKeys[k]
|
|
require.True(t, ok)
|
|
require.Equal(t, v, v2)
|
|
}
|
|
|
|
require.Equal(t, len(expected.Peers), len(actual.Peers))
|
|
for k, v := range expected.Peers {
|
|
v2, ok := actual.Peers[k]
|
|
require.True(t, ok)
|
|
require.Equal(t, v, v2)
|
|
}
|
|
|
|
require.Equal(t, len(expected.Users), len(actual.Users))
|
|
for k, v := range expected.Users {
|
|
v2, ok := actual.Users[k]
|
|
require.True(t, ok)
|
|
require.Equal(t, v, v2)
|
|
require.Equal(t, len(v.PATs), len(v2.PATs))
|
|
for k3, v3 := range v.PATs {
|
|
v4, ok := v2.PATs[k3]
|
|
require.True(t, ok)
|
|
require.Equal(t, v3, v4)
|
|
}
|
|
}
|
|
require.Equal(t, len(expected.Groups), len(actual.Groups))
|
|
for k, v := range expected.Groups {
|
|
v2, ok := actual.Groups[k]
|
|
require.True(t, ok)
|
|
require.Equal(t, v, v2)
|
|
}
|
|
require.Equal(t, len(expected.Routes), len(actual.Routes))
|
|
for k, v := range expected.Routes {
|
|
v2, ok := actual.Routes[k]
|
|
require.True(t, ok)
|
|
require.Equal(t, v, v2)
|
|
}
|
|
require.Equal(t, len(expected.NameServerGroups), len(actual.NameServerGroups))
|
|
for k, v := range expected.NameServerGroups {
|
|
v2, ok := actual.NameServerGroups[k]
|
|
require.True(t, ok)
|
|
require.Equal(t, v, v2)
|
|
}
|
|
|
|
require.Equal(t, expected.Policies, actual.Policies)
|
|
|
|
require.Equal(t, expected.PostureChecks, actual.PostureChecks)
|
|
|
|
require.Equal(t, expected.Network, actual.Network)
|
|
|
|
require.Equal(t, expected.Networks, actual.Networks)
|
|
require.Equal(t, expected.NetworkRouters, actual.NetworkRouters)
|
|
require.Equal(t, expected.NetworkResources, actual.NetworkResources)
|
|
require.Equal(t, expected.Onboarding, actual.Onboarding)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) {
|
|
var account types.Account
|
|
account.Network = &types.Network{}
|
|
const accountQuery = `
|
|
SELECT
|
|
id, created_by, created_at, domain, domain_category, is_domain_primary_account,
|
|
-- Embedded Network
|
|
network_identifier, network_net, network_dns, network_serial,
|
|
-- Embedded DNSSettings
|
|
dns_settings_disabled_management_groups,
|
|
-- Embedded Settings
|
|
settings_peer_login_expiration_enabled, settings_peer_login_expiration,
|
|
settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration,
|
|
settings_regular_users_view_blocked, settings_groups_propagation_enabled,
|
|
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
|
|
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
|
|
settings_lazy_connection_enabled,
|
|
-- Embedded ExtraSettings
|
|
settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
|
|
settings_extra_integrated_validator, settings_extra_integrated_validator_groups
|
|
FROM accounts WHERE id = $1`
|
|
|
|
var networkNet, dnsSettingsDisabledGroups []byte
|
|
var (
|
|
sPeerLoginExpirationEnabled sql.NullBool
|
|
sPeerLoginExpiration sql.NullInt64
|
|
sPeerInactivityExpirationEnabled sql.NullBool
|
|
sPeerInactivityExpiration sql.NullInt64
|
|
sRegularUsersViewBlocked sql.NullBool
|
|
sGroupsPropagationEnabled sql.NullBool
|
|
sJWTGroupsEnabled sql.NullBool
|
|
sJWTGroupsClaimName sql.NullString
|
|
sJWTAllowGroups []byte
|
|
sRoutingPeerDNSResolutionEnabled sql.NullBool
|
|
sDNSDomain sql.NullString
|
|
sNetworkRange []byte
|
|
sLazyConnectionEnabled sql.NullBool
|
|
sExtraPeerApprovalEnabled sql.NullBool
|
|
sExtraUserApprovalRequired sql.NullBool
|
|
sExtraIntegratedValidator sql.NullString
|
|
sExtraIntegratedValidatorGroups []byte
|
|
)
|
|
|
|
err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan(
|
|
&account.Id, &account.CreatedBy, &account.CreatedAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount,
|
|
&account.Network.Identifier, &networkNet, &account.Network.Dns, &account.Network.Serial,
|
|
&dnsSettingsDisabledGroups,
|
|
&sPeerLoginExpirationEnabled, &sPeerLoginExpiration,
|
|
&sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration,
|
|
&sRegularUsersViewBlocked, &sGroupsPropagationEnabled,
|
|
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
|
|
&sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
|
|
&sLazyConnectionEnabled,
|
|
&sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
|
|
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, errors.New("account not found")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
_ = json.Unmarshal(networkNet, &account.Network.Net)
|
|
_ = json.Unmarshal(dnsSettingsDisabledGroups, &account.DNSSettings.DisabledManagementGroups)
|
|
|
|
account.Settings = &types.Settings{Extra: &types.ExtraSettings{}}
|
|
if sPeerLoginExpirationEnabled.Valid {
|
|
account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool
|
|
}
|
|
if sPeerLoginExpiration.Valid {
|
|
account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64)
|
|
}
|
|
if sPeerInactivityExpirationEnabled.Valid {
|
|
account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool
|
|
}
|
|
if sPeerInactivityExpiration.Valid {
|
|
account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64)
|
|
}
|
|
if sRegularUsersViewBlocked.Valid {
|
|
account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool
|
|
}
|
|
if sGroupsPropagationEnabled.Valid {
|
|
account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool
|
|
}
|
|
if sJWTGroupsEnabled.Valid {
|
|
account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool
|
|
}
|
|
if sJWTGroupsClaimName.Valid {
|
|
account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String
|
|
}
|
|
if sRoutingPeerDNSResolutionEnabled.Valid {
|
|
account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool
|
|
}
|
|
if sDNSDomain.Valid {
|
|
account.Settings.DNSDomain = sDNSDomain.String
|
|
}
|
|
if sLazyConnectionEnabled.Valid {
|
|
account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool
|
|
}
|
|
if sJWTAllowGroups != nil {
|
|
_ = json.Unmarshal(sJWTAllowGroups, &account.Settings.JWTAllowGroups)
|
|
}
|
|
if sNetworkRange != nil {
|
|
_ = json.Unmarshal(sNetworkRange, &account.Settings.NetworkRange)
|
|
}
|
|
|
|
if sExtraPeerApprovalEnabled.Valid {
|
|
account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool
|
|
}
|
|
if sExtraUserApprovalRequired.Valid {
|
|
account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool
|
|
}
|
|
if sExtraIntegratedValidator.Valid {
|
|
account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String
|
|
}
|
|
if sExtraIntegratedValidatorGroups != nil {
|
|
_ = json.Unmarshal(sExtraIntegratedValidatorGroups, &account.Settings.Extra.IntegratedValidatorGroups)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
errChan := make(chan error, 12)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
|
|
keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) {
|
|
var sk types.SetupKey
|
|
var autoGroups []byte
|
|
var expiresAt, updatedAt, lastUsed sql.NullTime
|
|
var revoked, ephemeral, allowExtraDNSLabels sql.NullBool
|
|
var usedTimes, usageLimit sql.NullInt64
|
|
|
|
err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &sk.CreatedAt, &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels)
|
|
|
|
if err == nil {
|
|
if expiresAt.Valid {
|
|
sk.ExpiresAt = &expiresAt.Time
|
|
}
|
|
if updatedAt.Valid {
|
|
sk.UpdatedAt = updatedAt.Time
|
|
if sk.UpdatedAt.IsZero() {
|
|
sk.UpdatedAt = sk.CreatedAt
|
|
}
|
|
}
|
|
if lastUsed.Valid {
|
|
sk.LastUsed = &lastUsed.Time
|
|
}
|
|
if revoked.Valid {
|
|
sk.Revoked = revoked.Bool
|
|
}
|
|
if usedTimes.Valid {
|
|
sk.UsedTimes = int(usedTimes.Int64)
|
|
}
|
|
if usageLimit.Valid {
|
|
sk.UsageLimit = int(usageLimit.Int64)
|
|
}
|
|
if ephemeral.Valid {
|
|
sk.Ephemeral = ephemeral.Bool
|
|
}
|
|
if allowExtraDNSLabels.Valid {
|
|
sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
|
|
}
|
|
if autoGroups != nil {
|
|
_ = json.Unmarshal(autoGroups, &sk.AutoGroups)
|
|
} else {
|
|
sk.AutoGroups = []string{}
|
|
}
|
|
}
|
|
return sk, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.SetupKeysG = keys
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, location_geo_name_id FROM peers WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
|
|
peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) {
|
|
var p nbpeer.Peer
|
|
p.Status = &nbpeer.PeerStatus{}
|
|
var ip, extraDNS, netAddr, env, flags, files, connIP []byte
|
|
err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID)
|
|
if err == nil {
|
|
if ip != nil {
|
|
_ = json.Unmarshal(ip, &p.IP)
|
|
}
|
|
if extraDNS != nil {
|
|
_ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels)
|
|
}
|
|
if netAddr != nil {
|
|
_ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses)
|
|
}
|
|
if env != nil {
|
|
_ = json.Unmarshal(env, &p.Meta.Environment)
|
|
}
|
|
if flags != nil {
|
|
_ = json.Unmarshal(flags, &p.Meta.Flags)
|
|
}
|
|
if files != nil {
|
|
_ = json.Unmarshal(files, &p.Meta.Files)
|
|
}
|
|
if connIP != nil {
|
|
_ = json.Unmarshal(connIP, &p.Location.ConnectionIP)
|
|
}
|
|
}
|
|
return p, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.PeersG = peers
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
|
var u types.User
|
|
var autoGroups []byte
|
|
var lastLogin sql.NullTime
|
|
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
|
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &u.CreatedAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
|
|
if err == nil {
|
|
if lastLogin.Valid {
|
|
u.LastLogin = &lastLogin.Time
|
|
}
|
|
if isServiceUser.Valid {
|
|
u.IsServiceUser = isServiceUser.Bool
|
|
}
|
|
if nonDeletable.Valid {
|
|
u.NonDeletable = nonDeletable.Bool
|
|
}
|
|
if blocked.Valid {
|
|
u.Blocked = blocked.Bool
|
|
}
|
|
if pendingApproval.Valid {
|
|
u.PendingApproval = pendingApproval.Bool
|
|
}
|
|
if autoGroups != nil {
|
|
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
|
|
} else {
|
|
u.AutoGroups = []string{}
|
|
}
|
|
}
|
|
return u, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.UsersG = users
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) {
|
|
var g types.Group
|
|
var resources []byte
|
|
var refID sql.NullInt64
|
|
var refType sql.NullString
|
|
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
|
if err == nil {
|
|
if refID.Valid {
|
|
g.IntegrationReference.ID = int(refID.Int64)
|
|
}
|
|
if refType.Valid {
|
|
g.IntegrationReference.IntegrationType = refType.String
|
|
}
|
|
if resources != nil {
|
|
_ = json.Unmarshal(resources, &g.Resources)
|
|
} else {
|
|
g.Resources = []types.Resource{}
|
|
}
|
|
g.GroupPeers = []types.GroupPeer{}
|
|
g.Peers = []string{}
|
|
}
|
|
return &g, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.GroupsG = groups
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) {
|
|
var p types.Policy
|
|
var checks []byte
|
|
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks)
|
|
if err == nil && checks != nil {
|
|
_ = json.Unmarshal(checks, &p.SourcePostureChecks)
|
|
}
|
|
return &p, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.Policies = policies
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) {
|
|
var r route.Route
|
|
var network, domains, peerGroups, groups, accessGroups []byte
|
|
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply)
|
|
if err == nil {
|
|
if network != nil {
|
|
_ = json.Unmarshal(network, &r.Network)
|
|
}
|
|
if domains != nil {
|
|
_ = json.Unmarshal(domains, &r.Domains)
|
|
}
|
|
if peerGroups != nil {
|
|
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
|
|
}
|
|
if groups != nil {
|
|
_ = json.Unmarshal(groups, &r.Groups)
|
|
}
|
|
if accessGroups != nil {
|
|
_ = json.Unmarshal(accessGroups, &r.AccessControlGroups)
|
|
}
|
|
}
|
|
return r, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.RoutesG = routes
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) {
|
|
var n nbdns.NameServerGroup
|
|
var ns, groups, domains []byte
|
|
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled)
|
|
if err == nil {
|
|
if ns != nil {
|
|
_ = json.Unmarshal(ns, &n.NameServers)
|
|
} else {
|
|
n.NameServers = []nbdns.NameServer{}
|
|
}
|
|
if groups != nil {
|
|
_ = json.Unmarshal(groups, &n.Groups)
|
|
} else {
|
|
n.Groups = []string{}
|
|
}
|
|
if domains != nil {
|
|
_ = json.Unmarshal(domains, &n.Domains)
|
|
} else {
|
|
n.Domains = []string{}
|
|
}
|
|
}
|
|
return n, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.NameServerGroupsG = nsgs
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
|
|
var c posture.Checks
|
|
var checksDef []byte
|
|
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
|
|
if err == nil && checksDef != nil {
|
|
_ = json.Unmarshal(checksDef, &c.Checks)
|
|
}
|
|
return &c, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.PostureChecks = checks
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network])
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.Networks = make([]*networkTypes.Network, len(networks))
|
|
for i := range networks {
|
|
account.Networks[i] = &networks[i]
|
|
}
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) {
|
|
var r routerTypes.NetworkRouter
|
|
var peerGroups []byte
|
|
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled)
|
|
if err == nil && peerGroups != nil {
|
|
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
|
|
}
|
|
return &r, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.NetworkRouters = routers
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) {
|
|
var r resourceTypes.NetworkResource
|
|
var prefix []byte
|
|
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled)
|
|
if err == nil && prefix != nil {
|
|
_ = json.Unmarshal(prefix, &r.Prefix)
|
|
}
|
|
return &r, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.NetworkResources = resources
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
|
|
err := s.pool.QueryRow(ctx, query, accountID).Scan(
|
|
&account.Onboarding.AccountID,
|
|
&account.Onboarding.OnboardingFlowPending,
|
|
&account.Onboarding.SignupFormPending,
|
|
&account.Onboarding.CreatedAt,
|
|
&account.Onboarding.UpdatedAt,
|
|
)
|
|
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
close(errChan)
|
|
for e := range errChan {
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
}
|
|
|
|
var userIDs []string
|
|
for _, u := range account.UsersG {
|
|
userIDs = append(userIDs, u.Id)
|
|
}
|
|
var policyIDs []string
|
|
for _, p := range account.Policies {
|
|
policyIDs = append(policyIDs, p.ID)
|
|
}
|
|
var groupIDs []string
|
|
for _, g := range account.GroupsG {
|
|
groupIDs = append(groupIDs, g.ID)
|
|
}
|
|
|
|
wg.Add(3)
|
|
errChan = make(chan error, 3)
|
|
|
|
var pats []types.PersonalAccessToken
|
|
go func() {
|
|
defer wg.Done()
|
|
if len(userIDs) == 0 {
|
|
return
|
|
}
|
|
const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
|
|
rows, err := s.pool.Query(ctx, query, userIDs)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken])
|
|
if err != nil {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
var rules []*types.PolicyRule
|
|
go func() {
|
|
defer wg.Done()
|
|
if len(policyIDs) == 0 {
|
|
return
|
|
}
|
|
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
|
|
rows, err := s.pool.Query(ctx, query, policyIDs)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
|
var r types.PolicyRule
|
|
var dest, destRes, sources, sourceRes, ports, portRanges []byte
|
|
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges)
|
|
if err == nil {
|
|
if dest != nil {
|
|
_ = json.Unmarshal(dest, &r.Destinations)
|
|
}
|
|
if destRes != nil {
|
|
_ = json.Unmarshal(destRes, &r.DestinationResource)
|
|
}
|
|
if sources != nil {
|
|
_ = json.Unmarshal(sources, &r.Sources)
|
|
}
|
|
if sourceRes != nil {
|
|
_ = json.Unmarshal(sourceRes, &r.SourceResource)
|
|
}
|
|
if ports != nil {
|
|
_ = json.Unmarshal(ports, &r.Ports)
|
|
}
|
|
if portRanges != nil {
|
|
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
|
}
|
|
}
|
|
return &r, err
|
|
})
|
|
if err != nil {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
var groupPeers []types.GroupPeer
|
|
go func() {
|
|
defer wg.Done()
|
|
if len(groupIDs) == 0 {
|
|
return
|
|
}
|
|
const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
|
|
rows, err := s.pool.Query(ctx, query, groupIDs)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
groupPeers, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
|
|
if err != nil {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
close(errChan)
|
|
for e := range errChan {
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
}
|
|
|
|
patsByUserID := make(map[string][]*types.PersonalAccessToken)
|
|
for i := range pats {
|
|
pat := &pats[i]
|
|
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
|
|
pat.UserID = ""
|
|
}
|
|
|
|
rulesByPolicyID := make(map[string][]*types.PolicyRule)
|
|
for _, rule := range rules {
|
|
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
|
|
}
|
|
|
|
peersByGroupID := make(map[string][]string)
|
|
for _, gp := range groupPeers {
|
|
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
|
}
|
|
|
|
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
|
for i := range account.SetupKeysG {
|
|
key := &account.SetupKeysG[i]
|
|
account.SetupKeys[key.Key] = key
|
|
}
|
|
|
|
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
|
for i := range account.PeersG {
|
|
peer := &account.PeersG[i]
|
|
account.Peers[peer.ID] = peer
|
|
}
|
|
|
|
account.Users = make(map[string]*types.User, len(account.UsersG))
|
|
for i := range account.UsersG {
|
|
user := &account.UsersG[i]
|
|
user.PATs = make(map[string]*types.PersonalAccessToken)
|
|
if userPats, ok := patsByUserID[user.Id]; ok {
|
|
for j := range userPats {
|
|
pat := userPats[j]
|
|
user.PATs[pat.ID] = pat
|
|
}
|
|
}
|
|
account.Users[user.Id] = user
|
|
}
|
|
|
|
for i := range account.Policies {
|
|
policy := account.Policies[i]
|
|
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
|
|
policy.Rules = policyRules
|
|
}
|
|
}
|
|
|
|
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
|
for i := range account.GroupsG {
|
|
group := account.GroupsG[i]
|
|
if peerIDs, ok := peersByGroupID[group.ID]; ok {
|
|
group.Peers = peerIDs
|
|
}
|
|
account.Groups[group.ID] = group
|
|
}
|
|
|
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
|
for i := range account.RoutesG {
|
|
route := &account.RoutesG[i]
|
|
account.Routes[route.ID] = route
|
|
}
|
|
|
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
|
for i := range account.NameServerGroupsG {
|
|
nsg := &account.NameServerGroupsG[i]
|
|
nsg.AccountID = ""
|
|
account.NameServerGroups[nsg.ID] = nsg
|
|
}
|
|
|
|
account.SetupKeysG = nil
|
|
account.PeersG = nil
|
|
account.UsersG = nil
|
|
account.GroupsG = nil
|
|
account.RoutesG = nil
|
|
account.NameServerGroupsG = nil
|
|
|
|
return &account, nil
|
|
}
|