mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
Trim the fast-path Sync handler by removing two DB round trips on cache hit:
1. Consolidate GetUserIDByPeerKey + GetAccountIDByPeerPubKey into a single
GetPeerAuthInfoByPubKey store call. Both looked up the same peer row by
pubkey and returned one column each; the new method SELECTs both columns
in one query. AccountManager exposes it as GetPeerAuthInfo.
2. Extend peerSyncEntry with AccountID, PeerID, PeerKey, Ephemeral and a
HasUser flag so the cache carries everything the fast path needs. On
cache hit with a matching metaHash:
- The Sync handler skips GetPeerAuthInfo entirely (entry.AccountID and
entry.HasUser drive the loginFilter gate).
- commitFastPath skips GetPeerByPeerPubKey by using the cached peer
snapshot for OnPeerConnectedWithPeer.
Old cache entries from pre-step-2 shape still decode (missing fields zero
out) but IsComplete() returns false, so they fall through to the slow path
and get rewritten with the full shape on first pass. No migration needed.
Expected impact on a 16.8 s pathological Sync observed in production:
~6 s saved from eliminating one auth-read round trip, the pre-fast-path
GetPeerAuthInfo on cache hit, and GetPeerByPeerPubKey in commitFastPath.
Cache miss / cold start remain on the slow path unchanged.
Account-serial, ExtraSettings and peer-group caching — the remaining
synchronous DB reads — are deliberately left for a follow-up so the
invalidation design can be proven incrementally.
5725 lines
188 KiB
Go
5725 lines
188 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/rs/xid"
|
|
log "github.com/sirupsen/logrus"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
|
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
"github.com/netbirdio/netbird/management/server/posture"
|
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
"github.com/netbirdio/netbird/management/server/util"
|
|
"github.com/netbirdio/netbird/route"
|
|
"github.com/netbirdio/netbird/shared/management/status"
|
|
"github.com/netbirdio/netbird/util/crypt"
|
|
)
|
|
|
|
// Query fragments shared by the store methods. Each "?" is bound to a
// caller-supplied argument by gorm.
const (
	storeSqliteFileName = "store.db"

	idQueryCondition = "id = ?"
	// keyQueryCondition and mysqlKeyQueryCondition match the "key" column;
	// GetKeyQueryCondition picks the back-quoted form for the MySQL engine.
	keyQueryCondition              = "key = ?"
	mysqlKeyQueryCondition         = "`key` = ?"
	accountAndIDQueryCondition     = "account_id = ? and id = ?"
	accountAndPeerIDQueryCondition = "account_id = ? and peer_id = ?"
	accountAndIDsQueryCondition    = "account_id = ? AND id IN ?"
	accountIDCondition             = "account_id = ?"

	// peerNotFoundFMT is the format string for "peer not found" errors; the
	// single verb receives the peer ID.
	peerNotFoundFMT = "peer %s not found"
)

// Connection pool limits.
// NOTE(review): presumably applied to the pgx pool held in SqlStore.pool; the
// code that consumes them is not visible in this part of the file.
const (
	pgMaxConnections    = 30
	pgMinConnections    = 1
	pgMaxConnLifetime   = 60 * time.Minute
	pgHealthCheckPeriod = 1 * time.Minute
)
|
|
|
|
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
|
type SqlStore struct {
|
|
db *gorm.DB
|
|
globalAccountLock sync.Mutex
|
|
metrics telemetry.AppMetrics
|
|
installationPK int
|
|
storeEngine types.Engine
|
|
pool *pgxpool.Pool
|
|
fieldEncrypt *crypt.FieldEncrypt
|
|
transactionTimeout time.Duration
|
|
}
|
|
|
|
// installation is the gorm model of the row holding the installation ID.
// SaveInstallationID and GetInstallationID address it by primary key.
type installation struct {
	ID                  uint `gorm:"primaryKey"`
	InstallationIDValue string
}
|
|
|
|
// migrationFunc is the signature of a single migration step: it receives the
// gorm handle and returns an error when the step fails.
type migrationFunc func(*gorm.DB) error
|
|
|
|
// NewSqlStore creates a new SqlStore instance.
|
|
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
|
sql, err := db.DB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS"))
|
|
if err != nil {
|
|
conns = runtime.NumCPU()
|
|
}
|
|
|
|
transactionTimeout := 5 * time.Minute
|
|
if v := os.Getenv("NB_STORE_TRANSACTION_TIMEOUT"); v != "" {
|
|
if parsed, err := time.ParseDuration(v); err == nil {
|
|
transactionTimeout = parsed
|
|
}
|
|
}
|
|
log.WithContext(ctx).Infof("Setting transaction timeout to %v", transactionTimeout)
|
|
|
|
if storeEngine == types.SqliteStoreEngine {
|
|
if err == nil {
|
|
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
|
}
|
|
conns = 1
|
|
}
|
|
|
|
sql.SetMaxOpenConns(conns)
|
|
sql.SetMaxIdleConns(conns)
|
|
sql.SetConnMaxLifetime(time.Hour)
|
|
sql.SetConnMaxIdleTime(3 * time.Minute)
|
|
|
|
log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
|
|
conns, conns, time.Hour, 3*time.Minute)
|
|
|
|
if skipMigration {
|
|
log.WithContext(ctx).Infof("skipping migration")
|
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
|
|
}
|
|
|
|
if err := migratePreAuto(ctx, db); err != nil {
|
|
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
|
}
|
|
err = db.AutoMigrate(
|
|
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.ProxyAccessToken{},
|
|
&types.Group{}, &types.GroupPeer{},
|
|
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
|
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
|
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
|
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
|
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
|
}
|
|
if err := migratePostAuto(ctx, db); err != nil {
|
|
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
|
}
|
|
|
|
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1, transactionTimeout: transactionTimeout}, nil
|
|
}
|
|
|
|
func GetKeyQueryCondition(s *SqlStore) string {
|
|
if s.storeEngine == types.MysqlStoreEngine {
|
|
return mysqlKeyQueryCondition
|
|
}
|
|
return keyQueryCondition
|
|
}
|
|
|
|
// SaveJob persists a job in DB
|
|
func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error {
|
|
result := s.db.Create(job)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to create job in store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) CompletePeerJob(ctx context.Context, job *types.Job) error {
|
|
result := s.db.
|
|
Model(&types.Job{}).
|
|
Where(idQueryCondition, job.ID).
|
|
Updates(job)
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to update job in store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to update job in store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// job was pending for too long and has been cancelled
|
|
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error {
|
|
now := time.Now().UTC()
|
|
result := s.db.
|
|
Model(&types.Job{}).
|
|
Where(accountAndPeerIDQueryCondition+" AND id = ?"+" AND status = ?", accountID, peerID, jobID, types.JobStatusPending).
|
|
Updates(types.Job{
|
|
Status: types.JobStatusFailed,
|
|
FailedReason: reason,
|
|
CompletedAt: &now,
|
|
})
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// job was pending for too long and has been cancelled
|
|
func (s *SqlStore) MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error {
|
|
now := time.Now().UTC()
|
|
result := s.db.
|
|
Model(&types.Job{}).
|
|
Where(accountAndPeerIDQueryCondition+" AND status = ?", accountID, peerID, types.JobStatusPending).
|
|
Updates(types.Job{
|
|
Status: types.JobStatusFailed,
|
|
FailedReason: reason,
|
|
CompletedAt: &now,
|
|
})
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetJobByID fetches job by ID
|
|
func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) {
|
|
var job types.Job
|
|
err := s.db.
|
|
Where(accountAndIDQueryCondition, accountID, jobID).
|
|
First(&job).Error
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "job %s not found", jobID)
|
|
}
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to fetch job from store: %s", err)
|
|
return nil, err
|
|
}
|
|
return &job, nil
|
|
}
|
|
|
|
// get all jobs
|
|
func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) {
|
|
var jobs []*types.Job
|
|
err := s.db.
|
|
Where(accountAndPeerIDQueryCondition, accountID, peerID).
|
|
Order("created_at DESC").
|
|
Find(&jobs).Error
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err)
|
|
return nil, err
|
|
}
|
|
|
|
return jobs, nil
|
|
}
|
|
|
|
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
|
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
|
log.WithContext(ctx).Tracef("acquiring global lock")
|
|
start := time.Now()
|
|
s.globalAccountLock.Lock()
|
|
|
|
unlock = func() {
|
|
s.globalAccountLock.Unlock()
|
|
log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start))
|
|
}
|
|
|
|
took := time.Since(start)
|
|
log.WithContext(ctx).Tracef("took %v to acquire global lock", took)
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
|
|
}
|
|
|
|
return unlock
|
|
}
|
|
|
|
// Deprecated: Full account operations are no longer supported
|
|
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
|
|
start := time.Now()
|
|
defer func() {
|
|
elapsed := time.Since(start)
|
|
if elapsed > 1*time.Second {
|
|
log.WithContext(ctx).Tracef("SaveAccount for account %s exceeded 1s, took: %v", account.Id, elapsed)
|
|
}
|
|
}()
|
|
|
|
// todo: remove this check after the issue is resolved
|
|
s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain)
|
|
|
|
generateAccountSQLTypes(account)
|
|
|
|
// Encrypt sensitive user data before saving
|
|
for i := range account.UsersG {
|
|
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return fmt.Errorf("encrypt user: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, group := range account.GroupsG {
|
|
group.StoreGroupPeers()
|
|
}
|
|
|
|
err := s.transaction(func(tx *gorm.DB) error {
|
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.
|
|
Session(&gorm.Session{FullSaveAssociations: true}).
|
|
Clauses(clause.OnConflict{UpdateAll: true}).
|
|
Create(account)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
return nil
|
|
})
|
|
|
|
took := time.Since(start)
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
|
}
|
|
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
|
|
|
return err
|
|
}
|
|
|
|
// generateAccountSQLTypes generates the GORM compatible types for the account
|
|
func generateAccountSQLTypes(account *types.Account) {
|
|
for _, key := range account.SetupKeys {
|
|
account.SetupKeysG = append(account.SetupKeysG, *key)
|
|
}
|
|
|
|
if len(account.SetupKeys) != len(account.SetupKeysG) {
|
|
log.Warnf("SetupKeysG length mismatch for account %s", account.Id)
|
|
}
|
|
|
|
for id, peer := range account.Peers {
|
|
peer.ID = id
|
|
account.PeersG = append(account.PeersG, *peer)
|
|
}
|
|
|
|
for id, user := range account.Users {
|
|
user.Id = id
|
|
for id, pat := range user.PATs {
|
|
pat.ID = id
|
|
user.PATsG = append(user.PATsG, *pat)
|
|
}
|
|
account.UsersG = append(account.UsersG, *user)
|
|
}
|
|
|
|
for id, group := range account.Groups {
|
|
group.ID = id
|
|
group.AccountID = account.Id
|
|
account.GroupsG = append(account.GroupsG, group)
|
|
}
|
|
|
|
for id, route := range account.Routes {
|
|
route.ID = id
|
|
account.RoutesG = append(account.RoutesG, *route)
|
|
}
|
|
|
|
for id, ns := range account.NameServerGroups {
|
|
ns.ID = id
|
|
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
|
|
}
|
|
}
|
|
|
|
// checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank
|
|
func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) {
|
|
var acc types.Account
|
|
var domain string
|
|
result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).Take(&domain)
|
|
if result.Error != nil {
|
|
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error)
|
|
}
|
|
return
|
|
}
|
|
if domain != "" && newDomain == "" {
|
|
log.WithContext(ctx).Warnf("saving an account with empty domain when there was a domain set. Previous domain %s, Account ID: %s, Trace: %s", domain, accountID, debug.Stack())
|
|
}
|
|
}
|
|
|
|
func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error {
|
|
start := time.Now()
|
|
|
|
err := s.transaction(func(tx *gorm.DB) error {
|
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account.Services, "account_id = ?", account.Id)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
result = tx.Select(clause.Associations).Delete(account)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
took := time.Since(start)
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
|
}
|
|
log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *SqlStore) SaveInstallationID(_ context.Context, ID string) error {
|
|
installation := installation{InstallationIDValue: ID}
|
|
installation.ID = uint(s.installationPK)
|
|
|
|
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
|
|
}
|
|
|
|
func (s *SqlStore) GetInstallationID() string {
|
|
var installation installation
|
|
|
|
if result := s.db.Take(&installation, idQueryCondition, s.installationPK); result.Error != nil {
|
|
return ""
|
|
}
|
|
|
|
return installation.InstallationIDValue
|
|
}
|
|
|
|
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
|
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
|
|
peerCopy := peer.Copy()
|
|
peerCopy.AccountID = accountID
|
|
|
|
err := s.transaction(func(tx *gorm.DB) error {
|
|
// check if peer exists before saving
|
|
var peerID string
|
|
result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
|
|
}
|
|
return result.Error
|
|
}
|
|
|
|
if peerID == "" {
|
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
|
|
}
|
|
|
|
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
|
accountCopy := types.Account{
|
|
Domain: domain,
|
|
DomainCategory: category,
|
|
IsDomainPrimaryAccount: isPrimaryDomain,
|
|
}
|
|
|
|
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
|
result := s.db.Model(&types.Account{}).
|
|
Select(fieldsToUpdate).
|
|
Where(idQueryCondition, accountID).
|
|
Updates(&accountCopy)
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "account %s", accountID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
|
var peerCopy nbpeer.Peer
|
|
peerCopy.Status = &peerStatus
|
|
|
|
fieldsToUpdate := []string{
|
|
"peer_status_last_seen", "peer_status_connected",
|
|
"peer_status_login_expired", "peer_status_required_approval",
|
|
}
|
|
result := s.db.Model(&nbpeer.Peer{}).
|
|
Select(fieldsToUpdate).
|
|
Where(accountAndIDQueryCondition, accountID, peerID).
|
|
Updates(&peerCopy)
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
|
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
|
var peerCopy nbpeer.Peer
|
|
// Since the location field has been migrated to JSON serialization,
|
|
// updating the struct ensures the correct data format is inserted into the database.
|
|
peerCopy.Location = peerWithLocation.Location
|
|
|
|
result := s.db.Model(&nbpeer.Peer{}).
|
|
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
|
|
Updates(peerCopy)
|
|
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
|
|
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
|
|
result := s.db.Model(&nbpeer.Peer{}).
|
|
Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
|
|
Update("peer_status_requires_approval", false)
|
|
if result.Error != nil {
|
|
return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
|
|
}
|
|
|
|
return int(result.RowsAffected), nil
|
|
}
|
|
|
|
// SaveUsers saves the given list of users to the database.
|
|
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
|
if len(users) == 0 {
|
|
return nil
|
|
}
|
|
|
|
usersCopy := make([]*types.User, len(users))
|
|
for i, user := range users {
|
|
userCopy := user.Copy()
|
|
userCopy.Email = user.Email
|
|
userCopy.Name = user.Name
|
|
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return fmt.Errorf("encrypt user: %w", err)
|
|
}
|
|
usersCopy[i] = userCopy
|
|
}
|
|
|
|
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&usersCopy)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save users to store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SaveUser saves the given user to the database.
|
|
func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
|
userCopy := user.Copy()
|
|
userCopy.Email = user.Email
|
|
userCopy.Name = user.Name
|
|
|
|
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return fmt.Errorf("encrypt user: %w", err)
|
|
}
|
|
|
|
result := s.db.Save(userCopy)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save user to store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateGroups creates the given list of groups to the database.
|
|
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
|
|
if len(groups) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
|
result := tx.
|
|
Clauses(
|
|
clause.OnConflict{
|
|
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
|
UpdateAll: true,
|
|
},
|
|
).
|
|
Omit(clause.Associations).
|
|
Create(&groups)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save groups to store")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// UpdateGroups updates the given list of groups to the database.
|
|
func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
|
|
if len(groups) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
|
result := tx.
|
|
Clauses(
|
|
clause.OnConflict{
|
|
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
|
UpdateAll: true,
|
|
},
|
|
).
|
|
Omit(clause.Associations).
|
|
Create(&groups)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save groups to store")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
|
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
|
return nil
|
|
}
|
|
|
|
// DeleteTokenID2UserIDIndex is noop in SqlStore
|
|
func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
|
|
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// TODO: rework to not call GetAccount
|
|
return s.GetAccount(ctx, accountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountID string
|
|
result := tx.Model(&types.Account{}).Select("id").
|
|
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
|
strings.ToLower(domain), true, types.PrivateCategory,
|
|
).Take(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
|
}
|
|
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) {
|
|
var key types.SetupKey
|
|
result := s.db.Select("account_id").Take(&key, GetKeyQueryCondition(s), setupKey)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewSetupKeyNotFoundError(setupKey)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
|
|
}
|
|
|
|
if key.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(ctx, key.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
|
var token types.PersonalAccessToken
|
|
result := s.db.Take(&token, "hashed_token = ?", hashedToken)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return token.ID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var user types.User
|
|
result := tx.
|
|
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
|
Where("personal_access_tokens.id = ?", patID).Take(&user)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPATNotFoundError(patID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
|
return nil, status.NewGetUserFromStoreError()
|
|
}
|
|
|
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var user types.User
|
|
result := tx.Take(&user, idQueryCondition, userID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewUserNotFoundError(userID)
|
|
}
|
|
return nil, status.NewGetUserFromStoreError()
|
|
}
|
|
|
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error {
|
|
err := s.transaction(func(tx *gorm.DB) error {
|
|
result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
return tx.Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
|
|
})
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to delete user from store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var users []*types.User
|
|
result := tx.Find(&users, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
|
}
|
|
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
|
}
|
|
|
|
for _, user := range users {
|
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
|
}
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var user types.User
|
|
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
|
|
}
|
|
return nil, status.Errorf(status.Internal, "failed to get account owner from the store")
|
|
}
|
|
|
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// SaveUserInvite saves a user invite to the database
|
|
func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error {
|
|
inviteCopy := invite.Copy()
|
|
if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return fmt.Errorf("encrypt invite: %w", err)
|
|
}
|
|
|
|
result := s.db.Save(inviteCopy)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save user invite to store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetUserInviteByID retrieves a user invite by its ID and account ID
|
|
func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var invite types.UserInviteRecord
|
|
result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "user invite not found")
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
|
|
}
|
|
|
|
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt invite: %w", err)
|
|
}
|
|
|
|
return &invite, nil
|
|
}
|
|
|
|
// GetUserInviteByHashedToken retrieves a user invite by its hashed token
|
|
func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var invite types.UserInviteRecord
|
|
result := tx.Take(&invite, "hashed_token = ?", hashedToken)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "user invite not found")
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
|
|
}
|
|
|
|
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt invite: %w", err)
|
|
}
|
|
|
|
return &invite, nil
|
|
}
|
|
|
|
// GetUserInviteByEmail retrieves a user invite by account ID and email.
|
|
// Since email is encrypted with random IVs, we fetch all invites for the account
|
|
// and compare emails in memory after decryption.
|
|
func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var invites []*types.UserInviteRecord
|
|
result := tx.Find(&invites, "account_id = ?", accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
|
|
}
|
|
|
|
for _, invite := range invites {
|
|
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt invite: %w", err)
|
|
}
|
|
if strings.EqualFold(invite.Email, email) {
|
|
return invite, nil
|
|
}
|
|
}
|
|
|
|
return nil, status.Errorf(status.NotFound, "user invite not found for email")
|
|
}
|
|
|
|
// GetAccountUserInvites retrieves all user invites for an account
|
|
func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var invites []*types.UserInviteRecord
|
|
result := tx.Find(&invites, "account_id = ?", accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
|
|
}
|
|
|
|
for _, invite := range invites {
|
|
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt invite: %w", err)
|
|
}
|
|
}
|
|
|
|
return invites, nil
|
|
}
|
|
|
|
// DeleteUserInvite deletes a user invite by its ID
|
|
func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error {
|
|
result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete user invite from store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var groups []*types.Group
|
|
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
|
|
}
|
|
|
|
for _, g := range groups {
|
|
g.LoadGroupPeers()
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var groups []*types.Group
|
|
|
|
likePattern := `%"ID":"` + resourceID + `"%`
|
|
|
|
result := tx.
|
|
Preload(clause.Associations).
|
|
Where("resources LIKE ?", likePattern).
|
|
Find(&groups)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, result.Error
|
|
}
|
|
|
|
for _, g := range groups {
|
|
g.LoadGroupPeers()
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
|
|
var count int64
|
|
result := s.db.Model(&types.Account{}).Count(&count)
|
|
if result.Error != nil {
|
|
return 0, fmt.Errorf("failed to get all accounts counter: %w", result.Error)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
|
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
|
var total, validated int64
|
|
if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
|
|
return 0, 0, err
|
|
}
|
|
if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
|
|
return 0, 0, err
|
|
}
|
|
return total, validated, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
|
var accounts []types.Account
|
|
result := s.db.Find(&accounts)
|
|
if result.Error != nil {
|
|
return all
|
|
}
|
|
|
|
for _, account := range accounts {
|
|
if acc, err := s.GetAccount(ctx, account.Id); err == nil {
|
|
all = append(all, acc)
|
|
}
|
|
}
|
|
|
|
return all
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountMeta types.AccountMeta
|
|
result := tx.Model(&types.Account{}).
|
|
Take(&accountMeta, idQueryCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error)
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return &accountMeta, nil
|
|
}
|
|
|
|
// GetAccountOnboarding retrieves the onboarding information for a specific account.
|
|
func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) {
|
|
var accountOnboarding types.AccountOnboarding
|
|
result := s.db.Model(&accountOnboarding).Take(&accountOnboarding, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewAccountOnboardingNotFoundError(accountID)
|
|
}
|
|
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return &accountOnboarding, nil
|
|
}
|
|
|
|
// SaveAccountOnboarding updates the onboarding information for a specific account.
|
|
func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error {
|
|
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
|
return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
|
if s.pool != nil {
|
|
return s.getAccountPgx(ctx, accountID)
|
|
}
|
|
return s.getAccountGorm(ctx, accountID)
|
|
}
|
|
|
|
func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) {
|
|
start := time.Now()
|
|
defer func() {
|
|
elapsed := time.Since(start)
|
|
if elapsed > 1*time.Second {
|
|
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
|
|
}
|
|
}()
|
|
|
|
var account types.Account
|
|
result := s.db.Model(&account).
|
|
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
|
|
Preload("Policies.Rules").
|
|
Preload("SetupKeysG").
|
|
Preload("PeersG").
|
|
Preload("UsersG").
|
|
Preload("GroupsG.GroupPeers").
|
|
Preload("RoutesG").
|
|
Preload("NameServerGroupsG").
|
|
Preload("PostureChecks").
|
|
Preload("Networks").
|
|
Preload("NetworkRouters").
|
|
Preload("NetworkResources").
|
|
Preload("Onboarding").
|
|
Preload("Services.Targets").
|
|
Take(&account, idQueryCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
|
for _, key := range account.SetupKeysG {
|
|
if key.UpdatedAt.IsZero() {
|
|
key.UpdatedAt = key.CreatedAt
|
|
}
|
|
if key.AutoGroups == nil {
|
|
key.AutoGroups = []string{}
|
|
}
|
|
account.SetupKeys[key.Key] = &key
|
|
}
|
|
account.SetupKeysG = nil
|
|
|
|
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
|
for _, peer := range account.PeersG {
|
|
account.Peers[peer.ID] = &peer
|
|
}
|
|
account.PeersG = nil
|
|
account.Users = make(map[string]*types.User, len(account.UsersG))
|
|
for _, user := range account.UsersG {
|
|
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
|
|
for _, pat := range user.PATsG {
|
|
pat.UserID = ""
|
|
user.PATs[pat.ID] = &pat
|
|
}
|
|
if user.AutoGroups == nil {
|
|
user.AutoGroups = []string{}
|
|
}
|
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
|
}
|
|
account.Users[user.Id] = &user
|
|
user.PATsG = nil
|
|
}
|
|
account.UsersG = nil
|
|
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
|
for _, group := range account.GroupsG {
|
|
group.Peers = make([]string, len(group.GroupPeers))
|
|
for i, gp := range group.GroupPeers {
|
|
group.Peers[i] = gp.PeerID
|
|
}
|
|
if group.Resources == nil {
|
|
group.Resources = []types.Resource{}
|
|
}
|
|
account.Groups[group.ID] = group
|
|
}
|
|
account.GroupsG = nil
|
|
|
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
|
for _, route := range account.RoutesG {
|
|
account.Routes[route.ID] = &route
|
|
}
|
|
account.RoutesG = nil
|
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
|
for _, ns := range account.NameServerGroupsG {
|
|
ns.AccountID = ""
|
|
if ns.NameServers == nil {
|
|
ns.NameServers = []nbdns.NameServer{}
|
|
}
|
|
if ns.Groups == nil {
|
|
ns.Groups = []string{}
|
|
}
|
|
if ns.Domains == nil {
|
|
ns.Domains = []string{}
|
|
}
|
|
account.NameServerGroups[ns.ID] = &ns
|
|
}
|
|
account.NameServerGroupsG = nil
|
|
account.InitOnce()
|
|
return &account, nil
|
|
}
|
|
|
|
func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) {
|
|
account, err := s.getAccount(ctx, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
errChan := make(chan error, 12)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
keys, err := s.getSetupKeys(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.SetupKeysG = keys
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
peers, err := s.getPeers(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.PeersG = peers
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
users, err := s.getUsers(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.UsersG = users
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
groups, err := s.getGroups(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.GroupsG = groups
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
policies, err := s.getPolicies(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.Policies = policies
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
routes, err := s.getRoutes(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.RoutesG = routes
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
nsgs, err := s.getNameServerGroups(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.NameServerGroupsG = nsgs
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
checks, err := s.getPostureChecks(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.PostureChecks = checks
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
services, err := s.getServices(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.Services = services
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
networks, err := s.getNetworks(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.Networks = networks
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
routers, err := s.getNetworkRouters(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.NetworkRouters = routers
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
resources, err := s.getNetworkResources(ctx, accountID)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
account.NetworkResources = resources
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
err := s.getAccountOnboarding(ctx, accountID, account)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
close(errChan)
|
|
for e := range errChan {
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
}
|
|
|
|
var userIDs []string
|
|
for _, u := range account.UsersG {
|
|
userIDs = append(userIDs, u.Id)
|
|
}
|
|
var policyIDs []string
|
|
for _, p := range account.Policies {
|
|
policyIDs = append(policyIDs, p.ID)
|
|
}
|
|
var groupIDs []string
|
|
for _, g := range account.GroupsG {
|
|
groupIDs = append(groupIDs, g.ID)
|
|
}
|
|
|
|
wg.Add(3)
|
|
errChan = make(chan error, 3)
|
|
|
|
var pats []types.PersonalAccessToken
|
|
go func() {
|
|
defer wg.Done()
|
|
var err error
|
|
pats, err = s.getPersonalAccessTokens(ctx, userIDs)
|
|
if err != nil {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
var rules []*types.PolicyRule
|
|
go func() {
|
|
defer wg.Done()
|
|
var err error
|
|
rules, err = s.getPolicyRules(ctx, policyIDs)
|
|
if err != nil {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
var groupPeers []types.GroupPeer
|
|
go func() {
|
|
defer wg.Done()
|
|
var err error
|
|
groupPeers, err = s.getGroupPeers(ctx, groupIDs)
|
|
if err != nil {
|
|
errChan <- err
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
close(errChan)
|
|
for e := range errChan {
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
}
|
|
|
|
patsByUserID := make(map[string][]*types.PersonalAccessToken)
|
|
for i := range pats {
|
|
pat := &pats[i]
|
|
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
|
|
pat.UserID = ""
|
|
}
|
|
|
|
rulesByPolicyID := make(map[string][]*types.PolicyRule)
|
|
for _, rule := range rules {
|
|
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
|
|
}
|
|
|
|
peersByGroupID := make(map[string][]string)
|
|
for _, gp := range groupPeers {
|
|
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
|
}
|
|
|
|
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
|
for i := range account.SetupKeysG {
|
|
key := &account.SetupKeysG[i]
|
|
account.SetupKeys[key.Key] = key
|
|
}
|
|
|
|
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
|
for i := range account.PeersG {
|
|
peer := &account.PeersG[i]
|
|
account.Peers[peer.ID] = peer
|
|
}
|
|
|
|
account.Users = make(map[string]*types.User, len(account.UsersG))
|
|
for i := range account.UsersG {
|
|
user := &account.UsersG[i]
|
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
|
}
|
|
user.PATs = make(map[string]*types.PersonalAccessToken)
|
|
if userPats, ok := patsByUserID[user.Id]; ok {
|
|
for j := range userPats {
|
|
pat := userPats[j]
|
|
user.PATs[pat.ID] = pat
|
|
}
|
|
}
|
|
account.Users[user.Id] = user
|
|
}
|
|
|
|
for i := range account.Policies {
|
|
policy := account.Policies[i]
|
|
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
|
|
policy.Rules = policyRules
|
|
}
|
|
}
|
|
|
|
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
|
for i := range account.GroupsG {
|
|
group := account.GroupsG[i]
|
|
if peerIDs, ok := peersByGroupID[group.ID]; ok {
|
|
group.Peers = peerIDs
|
|
}
|
|
account.Groups[group.ID] = group
|
|
}
|
|
|
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
|
for i := range account.RoutesG {
|
|
route := &account.RoutesG[i]
|
|
account.Routes[route.ID] = route
|
|
}
|
|
|
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
|
for i := range account.NameServerGroupsG {
|
|
nsg := &account.NameServerGroupsG[i]
|
|
nsg.AccountID = ""
|
|
account.NameServerGroups[nsg.ID] = nsg
|
|
}
|
|
|
|
account.SetupKeysG = nil
|
|
account.PeersG = nil
|
|
account.UsersG = nil
|
|
account.GroupsG = nil
|
|
account.RoutesG = nil
|
|
account.NameServerGroupsG = nil
|
|
|
|
return account, nil
|
|
}
|
|
|
|
func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
|
var account types.Account
|
|
account.Network = &types.Network{}
|
|
const accountQuery = `
|
|
SELECT
|
|
id, created_by, created_at, domain, domain_category, is_domain_primary_account,
|
|
-- Embedded Network
|
|
network_identifier, network_net, network_dns, network_serial,
|
|
-- Embedded DNSSettings
|
|
dns_settings_disabled_management_groups,
|
|
-- Embedded Settings
|
|
settings_peer_login_expiration_enabled, settings_peer_login_expiration,
|
|
settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration,
|
|
settings_regular_users_view_blocked, settings_groups_propagation_enabled,
|
|
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
|
|
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
|
|
settings_lazy_connection_enabled,
|
|
-- Embedded ExtraSettings
|
|
settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
|
|
settings_extra_integrated_validator, settings_extra_integrated_validator_groups
|
|
FROM accounts WHERE id = $1`
|
|
|
|
var (
|
|
sPeerLoginExpirationEnabled sql.NullBool
|
|
sPeerLoginExpiration sql.NullInt64
|
|
sPeerInactivityExpirationEnabled sql.NullBool
|
|
sPeerInactivityExpiration sql.NullInt64
|
|
sRegularUsersViewBlocked sql.NullBool
|
|
sGroupsPropagationEnabled sql.NullBool
|
|
sJWTGroupsEnabled sql.NullBool
|
|
sJWTGroupsClaimName sql.NullString
|
|
sJWTAllowGroups sql.NullString
|
|
sRoutingPeerDNSResolutionEnabled sql.NullBool
|
|
sDNSDomain sql.NullString
|
|
sNetworkRange sql.NullString
|
|
sLazyConnectionEnabled sql.NullBool
|
|
sExtraPeerApprovalEnabled sql.NullBool
|
|
sExtraUserApprovalRequired sql.NullBool
|
|
sExtraIntegratedValidator sql.NullString
|
|
sExtraIntegratedValidatorGroups sql.NullString
|
|
networkNet sql.NullString
|
|
dnsSettingsDisabledGroups sql.NullString
|
|
networkIdentifier sql.NullString
|
|
networkDns sql.NullString
|
|
networkSerial sql.NullInt64
|
|
createdAt sql.NullTime
|
|
)
|
|
err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan(
|
|
&account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount,
|
|
&networkIdentifier, &networkNet, &networkDns, &networkSerial,
|
|
&dnsSettingsDisabledGroups,
|
|
&sPeerLoginExpirationEnabled, &sPeerLoginExpiration,
|
|
&sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration,
|
|
&sRegularUsersViewBlocked, &sGroupsPropagationEnabled,
|
|
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
|
|
&sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
|
|
&sLazyConnectionEnabled,
|
|
&sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
|
|
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(err)
|
|
}
|
|
|
|
account.Settings = &types.Settings{Extra: &types.ExtraSettings{}}
|
|
if networkNet.Valid {
|
|
_ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net)
|
|
}
|
|
if createdAt.Valid {
|
|
account.CreatedAt = createdAt.Time
|
|
}
|
|
if dnsSettingsDisabledGroups.Valid {
|
|
_ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups)
|
|
}
|
|
if networkIdentifier.Valid {
|
|
account.Network.Identifier = networkIdentifier.String
|
|
}
|
|
if networkDns.Valid {
|
|
account.Network.Dns = networkDns.String
|
|
}
|
|
if networkSerial.Valid {
|
|
account.Network.Serial = uint64(networkSerial.Int64)
|
|
}
|
|
if sPeerLoginExpirationEnabled.Valid {
|
|
account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool
|
|
}
|
|
if sPeerLoginExpiration.Valid {
|
|
account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64)
|
|
}
|
|
if sPeerInactivityExpirationEnabled.Valid {
|
|
account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool
|
|
}
|
|
if sPeerInactivityExpiration.Valid {
|
|
account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64)
|
|
}
|
|
if sRegularUsersViewBlocked.Valid {
|
|
account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool
|
|
}
|
|
if sGroupsPropagationEnabled.Valid {
|
|
account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool
|
|
}
|
|
if sJWTGroupsEnabled.Valid {
|
|
account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool
|
|
}
|
|
if sJWTGroupsClaimName.Valid {
|
|
account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String
|
|
}
|
|
if sRoutingPeerDNSResolutionEnabled.Valid {
|
|
account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool
|
|
}
|
|
if sDNSDomain.Valid {
|
|
account.Settings.DNSDomain = sDNSDomain.String
|
|
}
|
|
if sLazyConnectionEnabled.Valid {
|
|
account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool
|
|
}
|
|
if sJWTAllowGroups.Valid {
|
|
_ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups)
|
|
}
|
|
if sNetworkRange.Valid {
|
|
_ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange)
|
|
}
|
|
|
|
if sExtraPeerApprovalEnabled.Valid {
|
|
account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool
|
|
}
|
|
if sExtraUserApprovalRequired.Valid {
|
|
account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool
|
|
}
|
|
if sExtraIntegratedValidator.Valid {
|
|
account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String
|
|
}
|
|
if sExtraIntegratedValidatorGroups.Valid {
|
|
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
|
|
}
|
|
account.InitOnce()
|
|
return &account, nil
|
|
}
|
|
|
|
func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) {
|
|
const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at,
|
|
revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) {
|
|
var sk types.SetupKey
|
|
var autoGroups []byte
|
|
var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime
|
|
var revoked, ephemeral, allowExtraDNSLabels sql.NullBool
|
|
var usedTimes, usageLimit sql.NullInt64
|
|
|
|
err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt,
|
|
&expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels)
|
|
|
|
if err == nil {
|
|
if expiresAt.Valid {
|
|
sk.ExpiresAt = &expiresAt.Time
|
|
}
|
|
if skCreatedAt.Valid {
|
|
sk.CreatedAt = skCreatedAt.Time
|
|
}
|
|
if updatedAt.Valid {
|
|
sk.UpdatedAt = updatedAt.Time
|
|
if sk.UpdatedAt.IsZero() {
|
|
sk.UpdatedAt = sk.CreatedAt
|
|
}
|
|
}
|
|
if lastUsed.Valid {
|
|
sk.LastUsed = &lastUsed.Time
|
|
}
|
|
if revoked.Valid {
|
|
sk.Revoked = revoked.Bool
|
|
}
|
|
if usedTimes.Valid {
|
|
sk.UsedTimes = int(usedTimes.Int64)
|
|
}
|
|
if usageLimit.Valid {
|
|
sk.UsageLimit = int(usageLimit.Int64)
|
|
}
|
|
if ephemeral.Valid {
|
|
sk.Ephemeral = ephemeral.Bool
|
|
}
|
|
if allowExtraDNSLabels.Valid {
|
|
sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
|
|
}
|
|
if autoGroups != nil {
|
|
_ = json.Unmarshal(autoGroups, &sk.AutoGroups)
|
|
} else {
|
|
sk.AutoGroups = []string{}
|
|
}
|
|
}
|
|
return sk, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) {
|
|
const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled,
|
|
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
|
|
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
|
|
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
|
|
meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
|
|
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
|
|
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster FROM peers WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) {
|
|
var p nbpeer.Peer
|
|
p.Status = &nbpeer.PeerStatus{}
|
|
var (
|
|
lastLogin, createdAt sql.NullTime
|
|
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
|
peerStatusLastSeen sql.NullTime
|
|
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
|
|
ip, extraDNS, netAddr, env, flags, files, connIP []byte
|
|
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
|
|
metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString
|
|
metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString
|
|
locationCountryCode, locationCityName, proxyCluster sql.NullString
|
|
locationGeoNameID sql.NullInt64
|
|
)
|
|
|
|
err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled,
|
|
&loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS,
|
|
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
|
|
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
|
|
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files,
|
|
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
|
|
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster)
|
|
|
|
if err == nil {
|
|
if lastLogin.Valid {
|
|
p.LastLogin = &lastLogin.Time
|
|
}
|
|
if createdAt.Valid {
|
|
p.CreatedAt = createdAt.Time
|
|
}
|
|
if sshEnabled.Valid {
|
|
p.SSHEnabled = sshEnabled.Bool
|
|
}
|
|
if loginExpirationEnabled.Valid {
|
|
p.LoginExpirationEnabled = loginExpirationEnabled.Bool
|
|
}
|
|
if inactivityExpirationEnabled.Valid {
|
|
p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool
|
|
}
|
|
if ephemeral.Valid {
|
|
p.Ephemeral = ephemeral.Bool
|
|
}
|
|
if allowExtraDNSLabels.Valid {
|
|
p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool
|
|
}
|
|
if peerStatusLastSeen.Valid {
|
|
p.Status.LastSeen = peerStatusLastSeen.Time
|
|
}
|
|
if peerStatusConnected.Valid {
|
|
p.Status.Connected = peerStatusConnected.Bool
|
|
}
|
|
if peerStatusLoginExpired.Valid {
|
|
p.Status.LoginExpired = peerStatusLoginExpired.Bool
|
|
}
|
|
if peerStatusRequiresApproval.Valid {
|
|
p.Status.RequiresApproval = peerStatusRequiresApproval.Bool
|
|
}
|
|
if metaHostname.Valid {
|
|
p.Meta.Hostname = metaHostname.String
|
|
}
|
|
if metaGoOS.Valid {
|
|
p.Meta.GoOS = metaGoOS.String
|
|
}
|
|
if metaKernel.Valid {
|
|
p.Meta.Kernel = metaKernel.String
|
|
}
|
|
if metaCore.Valid {
|
|
p.Meta.Core = metaCore.String
|
|
}
|
|
if metaPlatform.Valid {
|
|
p.Meta.Platform = metaPlatform.String
|
|
}
|
|
if metaOS.Valid {
|
|
p.Meta.OS = metaOS.String
|
|
}
|
|
if metaOSVersion.Valid {
|
|
p.Meta.OSVersion = metaOSVersion.String
|
|
}
|
|
if metaWtVersion.Valid {
|
|
p.Meta.WtVersion = metaWtVersion.String
|
|
}
|
|
if metaUIVersion.Valid {
|
|
p.Meta.UIVersion = metaUIVersion.String
|
|
}
|
|
if metaKernelVersion.Valid {
|
|
p.Meta.KernelVersion = metaKernelVersion.String
|
|
}
|
|
if metaSystemSerialNumber.Valid {
|
|
p.Meta.SystemSerialNumber = metaSystemSerialNumber.String
|
|
}
|
|
if metaSystemProductName.Valid {
|
|
p.Meta.SystemProductName = metaSystemProductName.String
|
|
}
|
|
if metaSystemManufacturer.Valid {
|
|
p.Meta.SystemManufacturer = metaSystemManufacturer.String
|
|
}
|
|
if locationCountryCode.Valid {
|
|
p.Location.CountryCode = locationCountryCode.String
|
|
}
|
|
if locationCityName.Valid {
|
|
p.Location.CityName = locationCityName.String
|
|
}
|
|
if locationGeoNameID.Valid {
|
|
p.Location.GeoNameID = uint(locationGeoNameID.Int64)
|
|
}
|
|
if proxyEmbedded.Valid {
|
|
p.ProxyMeta.Embedded = proxyEmbedded.Bool
|
|
}
|
|
if proxyCluster.Valid {
|
|
p.ProxyMeta.Cluster = proxyCluster.String
|
|
}
|
|
if ip != nil {
|
|
_ = json.Unmarshal(ip, &p.IP)
|
|
}
|
|
if extraDNS != nil {
|
|
_ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels)
|
|
}
|
|
if netAddr != nil {
|
|
_ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses)
|
|
}
|
|
if env != nil {
|
|
_ = json.Unmarshal(env, &p.Meta.Environment)
|
|
}
|
|
if flags != nil {
|
|
_ = json.Unmarshal(flags, &p.Meta.Flags)
|
|
}
|
|
if files != nil {
|
|
_ = json.Unmarshal(files, &p.Meta.Files)
|
|
}
|
|
if connIP != nil {
|
|
_ = json.Unmarshal(connIP, &p.Location.ConnectionIP)
|
|
}
|
|
}
|
|
return p, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return peers, nil
|
|
}
|
|
|
|
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
|
|
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
|
var u types.User
|
|
var autoGroups []byte
|
|
var lastLogin, createdAt sql.NullTime
|
|
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
|
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
|
if err == nil {
|
|
if lastLogin.Valid {
|
|
u.LastLogin = &lastLogin.Time
|
|
}
|
|
if createdAt.Valid {
|
|
u.CreatedAt = createdAt.Time
|
|
}
|
|
if isServiceUser.Valid {
|
|
u.IsServiceUser = isServiceUser.Bool
|
|
}
|
|
if nonDeletable.Valid {
|
|
u.NonDeletable = nonDeletable.Bool
|
|
}
|
|
if blocked.Valid {
|
|
u.Blocked = blocked.Bool
|
|
}
|
|
if pendingApproval.Valid {
|
|
u.PendingApproval = pendingApproval.Bool
|
|
}
|
|
if autoGroups != nil {
|
|
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
|
|
} else {
|
|
u.AutoGroups = []string{}
|
|
}
|
|
}
|
|
return u, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
|
|
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) {
|
|
var g types.Group
|
|
var resources []byte
|
|
var refID sql.NullInt64
|
|
var refType sql.NullString
|
|
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
|
if err == nil {
|
|
if refID.Valid {
|
|
g.IntegrationReference.ID = int(refID.Int64)
|
|
}
|
|
if refType.Valid {
|
|
g.IntegrationReference.IntegrationType = refType.String
|
|
}
|
|
if resources != nil {
|
|
_ = json.Unmarshal(resources, &g.Resources)
|
|
} else {
|
|
g.Resources = []types.Resource{}
|
|
}
|
|
g.GroupPeers = []types.GroupPeer{}
|
|
g.Peers = []string{}
|
|
}
|
|
return &g, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return groups, nil
|
|
}
|
|
|
|
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
|
|
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) {
|
|
var p types.Policy
|
|
var checks []byte
|
|
var enabled sql.NullBool
|
|
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
|
|
if err == nil {
|
|
if enabled.Valid {
|
|
p.Enabled = enabled.Bool
|
|
}
|
|
if checks != nil {
|
|
_ = json.Unmarshal(checks, &p.SourcePostureChecks)
|
|
}
|
|
}
|
|
return &p, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return policies, nil
|
|
}
|
|
|
|
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
|
|
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) {
|
|
var r route.Route
|
|
var network, domains, peerGroups, groups, accessGroups []byte
|
|
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
|
|
var metric sql.NullInt64
|
|
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
|
if err == nil {
|
|
if keepRoute.Valid {
|
|
r.KeepRoute = keepRoute.Bool
|
|
}
|
|
if masquerade.Valid {
|
|
r.Masquerade = masquerade.Bool
|
|
}
|
|
if enabled.Valid {
|
|
r.Enabled = enabled.Bool
|
|
}
|
|
if skipAutoApply.Valid {
|
|
r.SkipAutoApply = skipAutoApply.Bool
|
|
}
|
|
if metric.Valid {
|
|
r.Metric = int(metric.Int64)
|
|
}
|
|
if network != nil {
|
|
_ = json.Unmarshal(network, &r.Network)
|
|
}
|
|
if domains != nil {
|
|
_ = json.Unmarshal(domains, &r.Domains)
|
|
}
|
|
if peerGroups != nil {
|
|
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
|
|
}
|
|
if groups != nil {
|
|
_ = json.Unmarshal(groups, &r.Groups)
|
|
}
|
|
if accessGroups != nil {
|
|
_ = json.Unmarshal(accessGroups, &r.AccessControlGroups)
|
|
}
|
|
}
|
|
return r, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return routes, nil
|
|
}
|
|
|
|
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
|
|
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) {
|
|
var n nbdns.NameServerGroup
|
|
var ns, groups, domains []byte
|
|
var primary, enabled, searchDomainsEnabled sql.NullBool
|
|
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
|
if err == nil {
|
|
if primary.Valid {
|
|
n.Primary = primary.Bool
|
|
}
|
|
if enabled.Valid {
|
|
n.Enabled = enabled.Bool
|
|
}
|
|
if searchDomainsEnabled.Valid {
|
|
n.SearchDomainsEnabled = searchDomainsEnabled.Bool
|
|
}
|
|
if ns != nil {
|
|
_ = json.Unmarshal(ns, &n.NameServers)
|
|
} else {
|
|
n.NameServers = []nbdns.NameServer{}
|
|
}
|
|
if groups != nil {
|
|
_ = json.Unmarshal(groups, &n.Groups)
|
|
} else {
|
|
n.Groups = []string{}
|
|
}
|
|
if domains != nil {
|
|
_ = json.Unmarshal(domains, &n.Domains)
|
|
} else {
|
|
n.Domains = []string{}
|
|
}
|
|
}
|
|
return n, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return nsgs, nil
|
|
}
|
|
|
|
func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
|
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
|
|
var c posture.Checks
|
|
var checksDef []byte
|
|
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
|
|
if err == nil && checksDef != nil {
|
|
_ = json.Unmarshal(checksDef, &c.Checks)
|
|
}
|
|
return &c, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return checks, nil
|
|
}
|
|
|
|
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
|
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
|
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
|
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
|
|
mode, listen_port, port_auto_assigned, source, source_peer, terminated
|
|
FROM services WHERE account_id = $1`
|
|
|
|
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
|
target_id, target_type, enabled
|
|
FROM targets WHERE service_id = ANY($1)`
|
|
|
|
serviceRows, err := s.pool.Query(ctx, serviceQuery, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
|
var s rpservice.Service
|
|
var auth []byte
|
|
var createdAt, certIssuedAt sql.NullTime
|
|
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
|
var mode, source, sourcePeer sql.NullString
|
|
var terminated, portAutoAssigned sql.NullBool
|
|
var listenPort sql.NullInt64
|
|
err := row.Scan(
|
|
&s.ID,
|
|
&s.AccountID,
|
|
&s.Name,
|
|
&s.Domain,
|
|
&s.Enabled,
|
|
&auth,
|
|
&createdAt,
|
|
&certIssuedAt,
|
|
&status,
|
|
&proxyCluster,
|
|
&s.PassHostHeader,
|
|
&s.RewriteRedirects,
|
|
&sessionPrivateKey,
|
|
&sessionPublicKey,
|
|
&mode,
|
|
&listenPort,
|
|
&portAutoAssigned,
|
|
&source,
|
|
&sourcePeer,
|
|
&terminated,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if auth != nil {
|
|
if err := json.Unmarshal(auth, &s.Auth); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
s.Meta = rpservice.Meta{}
|
|
if createdAt.Valid {
|
|
s.Meta.CreatedAt = createdAt.Time
|
|
}
|
|
if certIssuedAt.Valid {
|
|
t := certIssuedAt.Time
|
|
s.Meta.CertificateIssuedAt = &t
|
|
}
|
|
if status.Valid {
|
|
s.Meta.Status = status.String
|
|
}
|
|
if proxyCluster.Valid {
|
|
s.ProxyCluster = proxyCluster.String
|
|
}
|
|
if sessionPrivateKey.Valid {
|
|
s.SessionPrivateKey = sessionPrivateKey.String
|
|
}
|
|
if sessionPublicKey.Valid {
|
|
s.SessionPublicKey = sessionPublicKey.String
|
|
}
|
|
if mode.Valid {
|
|
s.Mode = mode.String
|
|
}
|
|
if source.Valid {
|
|
s.Source = source.String
|
|
}
|
|
if sourcePeer.Valid {
|
|
s.SourcePeer = sourcePeer.String
|
|
}
|
|
if terminated.Valid {
|
|
s.Terminated = terminated.Bool
|
|
}
|
|
if portAutoAssigned.Valid {
|
|
s.PortAutoAssigned = portAutoAssigned.Bool
|
|
}
|
|
if listenPort.Valid {
|
|
s.ListenPort = uint16(listenPort.Int64)
|
|
}
|
|
s.Targets = []*rpservice.Target{}
|
|
return &s, nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(services) == 0 {
|
|
return services, nil
|
|
}
|
|
|
|
serviceIDs := make([]string, len(services))
|
|
serviceMap := make(map[string]*rpservice.Service)
|
|
for i, s := range services {
|
|
serviceIDs[i] = s.ID
|
|
serviceMap[s.ID] = s
|
|
}
|
|
|
|
targetRows, err := s.pool.Query(ctx, targetsQuery, serviceIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
|
|
var t rpservice.Target
|
|
var path sql.NullString
|
|
err := row.Scan(
|
|
&t.ID,
|
|
&t.AccountID,
|
|
&t.ServiceID,
|
|
&path,
|
|
&t.Host,
|
|
&t.Port,
|
|
&t.Protocol,
|
|
&t.TargetId,
|
|
&t.TargetType,
|
|
&t.Enabled,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if path.Valid {
|
|
t.Path = &path.String
|
|
}
|
|
return &t, nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, target := range targets {
|
|
if service, ok := serviceMap[target.ServiceID]; ok {
|
|
service.Targets = append(service.Targets, target)
|
|
}
|
|
}
|
|
|
|
return services, nil
|
|
}
|
|
|
|
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
|
|
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make([]*networkTypes.Network, len(networks))
|
|
for i := range networks {
|
|
result[i] = &networks[i]
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
|
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) {
|
|
var r routerTypes.NetworkRouter
|
|
var peerGroups []byte
|
|
var masquerade, enabled sql.NullBool
|
|
var metric sql.NullInt64
|
|
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
|
if err == nil {
|
|
if masquerade.Valid {
|
|
r.Masquerade = masquerade.Bool
|
|
}
|
|
if enabled.Valid {
|
|
r.Enabled = enabled.Bool
|
|
}
|
|
if metric.Valid {
|
|
r.Metric = int(metric.Int64)
|
|
}
|
|
if peerGroups != nil {
|
|
_ = json.Unmarshal(peerGroups, &r.PeerGroups)
|
|
}
|
|
}
|
|
return r, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make([]*routerTypes.NetworkRouter, len(routers))
|
|
for i := range routers {
|
|
result[i] = &routers[i]
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
|
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
|
rows, err := s.pool.Query(ctx, query, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) {
|
|
var r resourceTypes.NetworkResource
|
|
var prefix []byte
|
|
var enabled sql.NullBool
|
|
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
|
if err == nil {
|
|
if enabled.Valid {
|
|
r.Enabled = enabled.Bool
|
|
}
|
|
if prefix != nil {
|
|
_ = json.Unmarshal(prefix, &r.Prefix)
|
|
}
|
|
}
|
|
return r, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make([]*resourceTypes.NetworkResource, len(resources))
|
|
for i := range resources {
|
|
result[i] = &resources[i]
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error {
|
|
const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1`
|
|
var onboardingFlowPending, signupFormPending sql.NullBool
|
|
var createdAt, updatedAt sql.NullTime
|
|
err := s.pool.QueryRow(ctx, query, accountID).Scan(
|
|
&account.Onboarding.AccountID,
|
|
&onboardingFlowPending,
|
|
&signupFormPending,
|
|
&createdAt,
|
|
&updatedAt,
|
|
)
|
|
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
|
return err
|
|
}
|
|
if createdAt.Valid {
|
|
account.Onboarding.CreatedAt = createdAt.Time
|
|
}
|
|
if updatedAt.Valid {
|
|
account.Onboarding.UpdatedAt = updatedAt.Time
|
|
}
|
|
if onboardingFlowPending.Valid {
|
|
account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool
|
|
}
|
|
if signupFormPending.Valid {
|
|
account.Onboarding.SignupFormPending = signupFormPending.Bool
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) {
|
|
if len(userIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)`
|
|
rows, err := s.pool.Query(ctx, query, userIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) {
|
|
var pat types.PersonalAccessToken
|
|
var expirationDate, lastUsed, createdAt sql.NullTime
|
|
err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed)
|
|
if err == nil {
|
|
if expirationDate.Valid {
|
|
pat.ExpirationDate = &expirationDate.Time
|
|
}
|
|
if createdAt.Valid {
|
|
pat.CreatedAt = createdAt.Time
|
|
}
|
|
if lastUsed.Valid {
|
|
pat.LastUsed = &lastUsed.Time
|
|
}
|
|
}
|
|
return pat, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return pats, nil
|
|
}
|
|
|
|
func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) {
|
|
if len(policyIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
|
|
rows, err := s.pool.Query(ctx, query, policyIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
|
var r types.PolicyRule
|
|
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
|
|
var enabled, bidirectional sql.NullBool
|
|
var authorizedUser sql.NullString
|
|
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &authorizedUser)
|
|
if err == nil {
|
|
if enabled.Valid {
|
|
r.Enabled = enabled.Bool
|
|
}
|
|
if bidirectional.Valid {
|
|
r.Bidirectional = bidirectional.Bool
|
|
}
|
|
if dest != nil {
|
|
_ = json.Unmarshal(dest, &r.Destinations)
|
|
}
|
|
if destRes != nil {
|
|
_ = json.Unmarshal(destRes, &r.DestinationResource)
|
|
}
|
|
if sources != nil {
|
|
_ = json.Unmarshal(sources, &r.Sources)
|
|
}
|
|
if sourceRes != nil {
|
|
_ = json.Unmarshal(sourceRes, &r.SourceResource)
|
|
}
|
|
if ports != nil {
|
|
_ = json.Unmarshal(ports, &r.Ports)
|
|
}
|
|
if portRanges != nil {
|
|
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
|
}
|
|
if authorizedGroups != nil {
|
|
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
|
|
}
|
|
if authorizedUser.Valid {
|
|
r.AuthorizedUser = authorizedUser.String
|
|
}
|
|
}
|
|
return &r, err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) {
|
|
if len(groupIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)`
|
|
rows, err := s.pool.Query(ctx, query, groupIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return groupPeers, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
|
|
var user types.User
|
|
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
if user.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(ctx, user.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
|
var peer nbpeer.Peer
|
|
result := s.db.Select("account_id").Take(&peer, idQueryCondition, peerID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
if peer.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(ctx, peer.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) {
|
|
var peer nbpeer.Peer
|
|
result := s.db.Select("account_id").Take(&peer, GetKeyQueryCondition(s), peerKey)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
if peer.AccountID == "" {
|
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return s.GetAccount(ctx, peer.AccountID)
|
|
}
|
|
|
|
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
|
|
var account types.Account
|
|
result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account)
|
|
if result.Error != nil {
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return account.Id, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
|
var peer nbpeer.Peer
|
|
var accountID string
|
|
result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).Take(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountID string
|
|
result := tx.Model(&types.User{}).
|
|
Select("account_id").Where(idQueryCondition, userID).Take(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountID string
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Select("account_id").Where(idQueryCondition, peerID).Take(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "peer %s account not found", peerID)
|
|
}
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
|
var accountID string
|
|
result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).Take(&accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.NewSetupKeyNotFoundError(setupKey)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
|
|
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
|
|
}
|
|
|
|
if accountID == "" {
|
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
}
|
|
|
|
return accountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var ipJSONStrings []string
|
|
|
|
// Fetch the IP addresses as JSON strings
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Where("account_id = ?", accountID).
|
|
Pluck("ip", &ipJSONStrings)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
|
}
|
|
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
|
|
}
|
|
|
|
// Convert the JSON strings to net.IP objects
|
|
ips := make([]net.IP, len(ipJSONStrings))
|
|
for i, ipJSON := range ipJSONStrings {
|
|
var ip net.IP
|
|
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
|
|
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
|
|
}
|
|
ips[i] = ip
|
|
}
|
|
|
|
return ips, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var labels []string
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%").
|
|
Pluck("dns_label", &labels)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "no peers found for the account")
|
|
}
|
|
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
|
|
}
|
|
|
|
return labels, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountNetwork types.AccountNetwork
|
|
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
|
}
|
|
return accountNetwork.Network, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peer nbpeer.Peer
|
|
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPeerNotFoundError(peerKey)
|
|
}
|
|
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
|
|
}
|
|
|
|
return &peer, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountSettings types.AccountSettings
|
|
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "settings not found")
|
|
}
|
|
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
|
|
}
|
|
return accountSettings.Settings, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var createdBy string
|
|
result := tx.Model(&types.Account{}).
|
|
Select("created_by").Take(&createdBy, idQueryCondition, accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
}
|
|
|
|
return createdBy, nil
|
|
}
|
|
|
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
|
var user types.User
|
|
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return status.NewUserNotFoundError(userID)
|
|
}
|
|
return status.NewGetUserFromStoreError()
|
|
}
|
|
|
|
if !lastLogin.IsZero() {
|
|
user.LastLogin = &lastLogin
|
|
return s.db.Save(&user).Error
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
|
definitionJSON, err := json.Marshal(checks)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var postureCheck posture.Checks
|
|
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).Take(&postureCheck).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postureCheck, nil
|
|
}
|
|
|
|
// Close closes the underlying DB connection
|
|
func (s *SqlStore) Close(_ context.Context) error {
|
|
sql, err := s.db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("get db: %w", err)
|
|
}
|
|
return sql.Close()
|
|
}
|
|
|
|
// GetStoreEngine returns underlying store engine
|
|
func (s *SqlStore) GetStoreEngine() types.Engine {
|
|
return s.storeEngine
|
|
}
|
|
|
|
// NewSqliteStore creates a new SQLite store.
|
|
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
|
storeFile := storeSqliteFileName
|
|
if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" {
|
|
storeFile = envFile
|
|
}
|
|
|
|
// Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc")
|
|
filePath, query, hasQuery := strings.Cut(storeFile, "?")
|
|
|
|
connStr := filePath
|
|
if !filepath.IsAbs(filePath) {
|
|
connStr = filepath.Join(dataDir, filePath)
|
|
}
|
|
|
|
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
|
|
if hasQuery {
|
|
connStr += "?" + query
|
|
} else if runtime.GOOS != "windows" {
|
|
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
|
connStr += "?cache=shared"
|
|
}
|
|
|
|
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics, skipMigration)
|
|
}
|
|
|
|
// NewPostgresqlStore creates a new Postgres store.
|
|
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
|
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pool, err := connectToPgDb(context.Background(), dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
|
|
if err != nil {
|
|
pool.Close()
|
|
return nil, err
|
|
}
|
|
store.pool = pool
|
|
return store, nil
|
|
}
|
|
|
|
func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
|
config, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse database config: %w", err)
|
|
}
|
|
|
|
config.MaxConns = pgMaxConnections
|
|
config.MinConns = pgMinConnections
|
|
config.MaxConnLifetime = pgMaxConnLifetime
|
|
config.HealthCheckPeriod = pgHealthCheckPeriod
|
|
|
|
pool, err := pgxpool.NewWithConfig(ctx, config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to create connection pool: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("unable to ping database: %w", err)
|
|
}
|
|
|
|
return pool, nil
|
|
}
|
|
|
|
// NewMysqlStore creates a new MySQL store.
|
|
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
|
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics, skipMigration)
|
|
}
|
|
|
|
func getGormConfig() *gorm.Config {
|
|
return &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
CreateBatchSize: 400,
|
|
}
|
|
}
|
|
|
|
// newPostgresStore initializes a new Postgres store.
|
|
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
|
|
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
|
}
|
|
return NewPostgresqlStore(ctx, dsn, metrics, skipMigration)
|
|
}
|
|
|
|
// newMysqlStore initializes a new MySQL store.
|
|
func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
|
|
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
|
}
|
|
return NewMysqlStore(ctx, dsn, metrics, skipMigration)
|
|
}
|
|
|
|
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
|
|
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
|
store, err := NewSqliteStore(ctx, dataDir, metrics, skipMigration)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = store.SaveInstallationID(ctx, fileStore.InstallationID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, account := range fileStore.GetAllAccounts(ctx) {
|
|
_, err = account.GetGroupAll()
|
|
if err != nil {
|
|
if err := account.AddAllGroup(false); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
err := store.SaveAccount(ctx, account)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
|
|
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, account := range sqliteStore.GetAllAccounts(ctx) {
|
|
err := store.SaveAccount(ctx, account)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
// used for tests only
|
|
func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
|
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pool, err := connectToPgDbForTests(context.Background(), dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration)
|
|
if err != nil {
|
|
pool.Close()
|
|
return nil, err
|
|
}
|
|
store.pool = pool
|
|
return store, nil
|
|
}
|
|
|
|
// used for tests only
|
|
func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
|
config, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse database config: %w", err)
|
|
}
|
|
|
|
config.MaxConns = 5
|
|
config.MinConns = 1
|
|
config.MaxConnLifetime = 30 * time.Second
|
|
config.HealthCheckPeriod = 10 * time.Second
|
|
|
|
pool, err := pgxpool.NewWithConfig(ctx, config)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to create connection pool: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("unable to ping database: %w", err)
|
|
}
|
|
|
|
return pool, nil
|
|
}
|
|
|
|
// NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB.
|
|
func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
|
store, err := NewMysqlStore(ctx, dsn, metrics, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, account := range sqliteStore.GetAllAccounts(ctx) {
|
|
err := store.SaveAccount(ctx, account)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var setupKey types.SetupKey
|
|
result := tx.
|
|
Take(&setupKey, GetKeyQueryCondition(s), key)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.PreconditionFailed, "setup key not found")
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
|
}
|
|
return &setupKey, nil
|
|
}
|
|
|
|
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
|
result := s.db.Model(&types.SetupKey{}).
|
|
Where(idQueryCondition, setupKeyID).
|
|
Updates(map[string]interface{}{
|
|
"used_times": gorm.Expr("used_times + 1"),
|
|
"last_used": time.Now(),
|
|
})
|
|
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewSetupKeyNotFoundError(setupKeyID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
|
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
|
var groupID string
|
|
_ = s.db.Model(types.Group{}).
|
|
Select("id").
|
|
Where("account_id = ? AND name = ?", accountID, "All").
|
|
Limit(1).
|
|
Scan(&groupID)
|
|
|
|
if groupID == "" {
|
|
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
|
|
}
|
|
|
|
err := s.db.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
|
DoNothing: true,
|
|
}).Create(&types.GroupPeer{
|
|
AccountID: accountID,
|
|
GroupID: groupID,
|
|
PeerID: peerID,
|
|
}).Error
|
|
|
|
if err != nil {
|
|
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddPeerToGroup adds a peer to a group
|
|
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
|
peer := &types.GroupPeer{
|
|
AccountID: accountID,
|
|
GroupID: groupID,
|
|
PeerID: peerID,
|
|
}
|
|
|
|
err := s.db.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
|
DoNothing: true,
|
|
}).Create(peer).Error
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err)
|
|
return status.Errorf(status.Internal, "failed to add peer to group")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemovePeerFromGroup removes a peer from a group
|
|
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
|
err := s.db.
|
|
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
|
return status.Errorf(status.Internal, "failed to remove peer from group")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemovePeerFromAllGroups removes a peer from all groups
|
|
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
|
err := s.db.
|
|
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
|
|
return status.Errorf(status.Internal, "failed to remove peer from all groups")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction
|
|
func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error {
|
|
var group types.Group
|
|
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).Take(&group)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return status.NewGroupNotFoundError(groupID)
|
|
}
|
|
|
|
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
|
}
|
|
|
|
for _, res := range group.Resources {
|
|
if res.ID == resource.ID {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
group.Resources = append(group.Resources, *resource)
|
|
|
|
if err := s.db.Save(&group).Error; err != nil {
|
|
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction
|
|
func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error {
|
|
var group types.Group
|
|
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).Take(&group)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return status.NewGroupNotFoundError(groupID)
|
|
}
|
|
|
|
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
|
}
|
|
|
|
for i, res := range group.Resources {
|
|
if res.ID == resourceID {
|
|
group.Resources = append(group.Resources[:i], group.Resources[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := s.db.Save(&group).Error; err != nil {
|
|
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
|
|
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var groups []*types.Group
|
|
query := tx.
|
|
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
|
|
Where("group_peers.peer_id = ?", peerId).
|
|
Preload(clause.Associations).
|
|
Find(&groups)
|
|
|
|
if query.Error != nil {
|
|
return nil, query.Error
|
|
}
|
|
|
|
for _, group := range groups {
|
|
group.LoadGroupPeers()
|
|
}
|
|
|
|
return groups, nil
|
|
}
|
|
|
|
// GetPeerGroupIDs retrieves all group IDs assigned to a specific peer in a given account.
|
|
func (s *SqlStore) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var groupIDs []string
|
|
query := tx.
|
|
Model(&types.GroupPeer{}).
|
|
Where("account_id = ? AND peer_id = ?", accountId, peerId).
|
|
Pluck("group_id", &groupIDs)
|
|
|
|
if query.Error != nil {
|
|
if errors.Is(query.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "no groups found for peer %s in account %s", peerId, accountId)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get group IDs for peer %s in account %s: %v", peerId, accountId, query.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get group IDs for peer from store")
|
|
}
|
|
|
|
return groupIDs, nil
|
|
}
|
|
|
|
// GetAccountPeers retrieves peers for an account.
|
|
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
|
var peers []*nbpeer.Peer
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
query := tx.Where(accountIDCondition, accountID)
|
|
|
|
if nameFilter != "" {
|
|
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
|
|
}
|
|
if ipFilter != "" {
|
|
query = query.Where("ip LIKE ?", "%"+ipFilter+"%")
|
|
}
|
|
|
|
if err := query.Find(&peers).Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get peers from store")
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// GetUserPeers retrieves peers for a user.
|
|
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peers []*nbpeer.Peer
|
|
|
|
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
|
|
if userID == "" {
|
|
return peers, nil
|
|
}
|
|
|
|
result := tx.
|
|
Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get peers from store")
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
|
if err := s.db.Create(peer).Error; err != nil {
|
|
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetPeerByID retrieves a peer by its ID and account ID.
|
|
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peer *nbpeer.Peer
|
|
result := tx.
|
|
Take(&peer, accountAndIDQueryCondition, accountID, peerID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPeerNotFoundError(peerID)
|
|
}
|
|
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
|
}
|
|
|
|
return peer, nil
|
|
}
|
|
|
|
// GetPeersByIDs retrieves peers by their IDs and account ID.
|
|
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peers []*nbpeer.Peer
|
|
result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
|
|
}
|
|
|
|
peersMap := make(map[string]*nbpeer.Peer)
|
|
for _, peer := range peers {
|
|
peersMap[peer.ID] = peer
|
|
}
|
|
|
|
return peersMap, nil
|
|
}
|
|
|
|
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
|
|
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peers []*nbpeer.Peer
|
|
result := tx.
|
|
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
|
Find(&peers, accountIDCondition, accountID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store")
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
|
|
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peers []*nbpeer.Peer
|
|
result := tx.
|
|
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
|
Find(&peers, accountIDCondition, accountID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store")
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
|
|
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var allEphemeralPeers, batchPeers []*nbpeer.Peer
|
|
result := tx.
|
|
Where("ephemeral = ?", true).
|
|
FindInBatches(&batchPeers, 1000, func(tx *gorm.DB, batch int) error {
|
|
allEphemeralPeers = append(allEphemeralPeers, batchPeers...)
|
|
return nil
|
|
})
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error)
|
|
return nil, fmt.Errorf("failed to retrieve ephemeral peers")
|
|
}
|
|
|
|
return allEphemeralPeers, nil
|
|
}
|
|
|
|
// DeletePeer removes a peer from the store.
|
|
func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID string) error {
|
|
result := s.db.Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to delete peer from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewPeerNotFoundError(peerID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
|
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
|
timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
|
|
defer cancel()
|
|
|
|
startTime := time.Now()
|
|
tx := s.db.WithContext(timeoutCtx).Begin()
|
|
if tx.Error != nil {
|
|
return tx.Error
|
|
}
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
tx.Rollback()
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
if s.storeEngine == types.PostgresStoreEngine {
|
|
if err := tx.Exec("SET LOCAL statement_timeout = '1min'").Error; err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to set statement timeout: %w", err)
|
|
}
|
|
if err := tx.Exec("SET LOCAL lock_timeout = '1min'").Error; err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to set lock timeout: %w", err)
|
|
}
|
|
}
|
|
|
|
// For MySQL, disable FK checks within this transaction to avoid deadlocks
|
|
// This is session-scoped and doesn't require SUPER privileges
|
|
if s.storeEngine == types.MysqlStoreEngine {
|
|
if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to disable FK checks: %w", err)
|
|
}
|
|
}
|
|
|
|
repo := s.withTx(tx)
|
|
err := operation(repo)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
|
log.WithContext(ctx).Warnf("transaction exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Re-enable FK checks before commit (optional, as transaction end resets it)
|
|
if s.storeEngine == types.MysqlStoreEngine {
|
|
if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to re-enable FK checks: %w", err)
|
|
}
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
|
|
log.WithContext(ctx).Warnf("transaction commit exceeded %s timeout after %v, stack: %s", s.transactionTimeout, time.Since(startTime), debug.Stack())
|
|
}
|
|
return err
|
|
}
|
|
|
|
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
|
if s.metrics != nil {
|
|
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
|
return &SqlStore{
|
|
db: tx,
|
|
storeEngine: s.storeEngine,
|
|
fieldEncrypt: s.fieldEncrypt,
|
|
}
|
|
}
|
|
|
|
// transaction wraps a GORM transaction with MySQL-specific FK checks handling
|
|
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
|
|
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
|
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
|
// For MySQL, disable FK checks within this transaction to avoid deadlocks
|
|
// This is session-scoped and doesn't require SUPER privileges
|
|
if s.storeEngine == types.MysqlStoreEngine {
|
|
if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil {
|
|
return fmt.Errorf("failed to disable FK checks: %w", err)
|
|
}
|
|
}
|
|
|
|
err := fn(tx)
|
|
|
|
// Re-enable FK checks before commit (optional, as transaction end resets it)
|
|
if s.storeEngine == types.MysqlStoreEngine && err == nil {
|
|
if fkErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; fkErr != nil {
|
|
return fmt.Errorf("failed to re-enable FK checks: %w", fkErr)
|
|
}
|
|
}
|
|
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (s *SqlStore) GetDB() *gorm.DB {
|
|
return s.db
|
|
}
|
|
|
|
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
|
func (s *SqlStore) SetFieldEncrypt(enc *crypt.FieldEncrypt) {
|
|
s.fieldEncrypt = enc
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountDNSSettings types.AccountDNSSettings
|
|
result := tx.Model(&types.Account{}).
|
|
Take(&accountDNSSettings, idQueryCondition, accountID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewAccountNotFoundError(accountID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
|
|
}
|
|
return &accountDNSSettings.DNSSettings, nil
|
|
}
|
|
|
|
// AccountExists checks whether an account exists by the given ID.
|
|
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var accountID string
|
|
result := tx.Model(&types.Account{}).
|
|
Select("id").Take(&accountID, idQueryCondition, id)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return false, nil
|
|
}
|
|
return false, result.Error
|
|
}
|
|
|
|
return accountID != "", nil
|
|
}
|
|
|
|
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
|
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var account types.Account
|
|
result := tx.Model(&types.Account{}).Select("domain", "domain_category").
|
|
Where(idQueryCondition, accountID).Take(&account)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", "", status.Errorf(status.NotFound, "account not found")
|
|
}
|
|
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
|
|
}
|
|
|
|
return account.Domain, account.DomainCategory, nil
|
|
}
|
|
|
|
// GetGroupByID retrieves a group by ID and account ID.
|
|
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var group *types.Group
|
|
result := tx.Preload(clause.Associations).Take(&group, accountAndIDQueryCondition, accountID, groupID)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewGroupNotFoundError(groupID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
|
}
|
|
|
|
group.LoadGroupPeers()
|
|
|
|
return group, nil
|
|
}
|
|
|
|
// GetGroupByName retrieves a group by name and account ID.
|
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
|
|
tx := s.db
|
|
|
|
var group types.Group
|
|
|
|
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
|
// we may need to reconsider changing the types.
|
|
query := tx.Preload(clause.Associations)
|
|
|
|
result := query.
|
|
Model(&types.Group{}).
|
|
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
|
|
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
|
|
Group("groups.id").
|
|
Order("COUNT(group_peers.peer_id) DESC").
|
|
Limit(1).
|
|
First(&group)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewGroupNotFoundError(groupName)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
|
}
|
|
|
|
group.LoadGroupPeers()
|
|
|
|
return &group, nil
|
|
}
|
|
|
|
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
|
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var groups []*types.Group
|
|
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
|
}
|
|
|
|
groupsMap := make(map[string]*types.Group)
|
|
for _, group := range groups {
|
|
group.LoadGroupPeers()
|
|
groupsMap[group.ID] = group
|
|
}
|
|
|
|
return groupsMap, nil
|
|
}
|
|
|
|
// CreateGroup creates a group in the store.
|
|
func (s *SqlStore) CreateGroup(ctx context.Context, group *types.Group) error {
|
|
if group == nil {
|
|
return status.Errorf(status.InvalidArgument, "group is nil")
|
|
}
|
|
|
|
if err := s.db.Omit(clause.Associations).Create(group).Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
|
return status.Errorf(status.Internal, "failed to save group to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateGroup updates a group in the store.
|
|
func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
|
|
if group == nil {
|
|
return status.Errorf(status.InvalidArgument, "group is nil")
|
|
}
|
|
|
|
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
|
return status.Errorf(status.Internal, "failed to save group to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteGroup deletes a group from the database.
|
|
func (s *SqlStore) DeleteGroup(ctx context.Context, accountID, groupID string) error {
|
|
result := s.db.Select(clause.Associations).
|
|
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete group from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewGroupNotFoundError(groupID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteGroups deletes groups from the database.
|
|
func (s *SqlStore) DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error {
|
|
result := s.db.Select(clause.Associations).
|
|
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete groups from store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAccountPolicies retrieves policies for an account.
|
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var policies []*types.Policy
|
|
result := tx.
|
|
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get policies from store")
|
|
}
|
|
|
|
return policies, nil
|
|
}
|
|
|
|
// GetPolicyByID retrieves a policy by its ID and account ID.
|
|
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var policy *types.Policy
|
|
|
|
result := tx.Preload(clause.Associations).
|
|
Take(&policy, accountAndIDQueryCondition, accountID, policyID)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPolicyNotFoundError(policyID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get policy from store")
|
|
}
|
|
|
|
return policy, nil
|
|
}
|
|
|
|
func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error {
|
|
result := s.db.Create(policy)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to create policy in store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SavePolicy saves a policy to the database.
|
|
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
|
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to save policy to store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
|
|
return s.transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
|
|
return fmt.Errorf("delete policy rules: %w", err)
|
|
}
|
|
|
|
result := tx.
|
|
Where(accountAndIDQueryCondition, accountID, policyID).
|
|
Delete(&types.Policy{})
|
|
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to delete policy from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewPolicyNotFoundError(policyID)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *SqlStore) GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, resourceID string) ([]*types.PolicyRule, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var policyRules []*types.PolicyRule
|
|
resourceIDPattern := `%"ID":"` + resourceID + `"%`
|
|
result := tx.Where("source_resource LIKE ? OR destination_resource LIKE ?", resourceIDPattern, resourceIDPattern).
|
|
Find(&policyRules)
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get policy rules for resource id from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get policy rules for resource id from store")
|
|
}
|
|
|
|
return policyRules, nil
|
|
}
|
|
|
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var postureChecks []*posture.Checks
|
|
result := tx.Find(&postureChecks, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
|
|
}
|
|
|
|
return postureChecks, nil
|
|
}
|
|
|
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var postureCheck *posture.Checks
|
|
result := tx.
|
|
Take(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
|
|
}
|
|
|
|
return postureCheck, nil
|
|
}
|
|
|
|
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
|
|
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var postureChecks []*posture.Checks
|
|
result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
|
|
}
|
|
|
|
postureChecksMap := make(map[string]*posture.Checks)
|
|
for _, postureCheck := range postureChecks {
|
|
postureChecksMap[postureCheck.ID] = postureCheck
|
|
}
|
|
|
|
return postureChecksMap, nil
|
|
}
|
|
|
|
// SavePostureChecks saves a posture checks to the database.
|
|
func (s *SqlStore) SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error {
|
|
result := s.db.Save(postureCheck)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save posture checks to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeletePostureChecks deletes a posture checks from the database.
|
|
func (s *SqlStore) DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error {
|
|
result := s.db.Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete posture checks from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewPostureChecksNotFoundError(postureChecksID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAccountRoutes retrieves network routes for an account.
|
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var routes []*route.Route
|
|
result := tx.Find(&routes, accountIDCondition, accountID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
|
}
|
|
|
|
return routes, nil
|
|
}
|
|
|
|
// GetRouteByID retrieves a route by its ID and account ID.
|
|
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var route *route.Route
|
|
result := tx.Take(&route, accountAndIDQueryCondition, accountID, routeID)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewRouteNotFoundError(routeID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get route from store")
|
|
}
|
|
|
|
return route, nil
|
|
}
|
|
|
|
// SaveRoute saves a route to the database.
|
|
func (s *SqlStore) SaveRoute(ctx context.Context, route *route.Route) error {
|
|
result := s.db.Save(route)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to save route to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteRoute deletes a route from the database.
|
|
func (s *SqlStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
|
result := s.db.Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to delete route from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewRouteNotFoundError(routeID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
|
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var setupKeys []*types.SetupKey
|
|
result := tx.
|
|
Find(&setupKeys, accountIDCondition, accountID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
|
|
}
|
|
|
|
return setupKeys, nil
|
|
}
|
|
|
|
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
|
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var setupKey *types.SetupKey
|
|
result := tx.Take(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
|
}
|
|
|
|
return setupKey, nil
|
|
}
|
|
|
|
// SaveSetupKey saves a setup key to the database.
|
|
func (s *SqlStore) SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error {
|
|
result := s.db.Save(setupKey)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save setup key to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteSetupKey deletes a setup key from the database.
|
|
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
|
result := s.db.Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete setup key from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewSetupKeyNotFoundError(keyID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAccountNameServerGroups retrieves name server groups for an account.
|
|
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var nsGroups []*nbdns.NameServerGroup
|
|
result := tx.Find(&nsGroups, accountIDCondition, accountID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
|
|
}
|
|
|
|
return nsGroups, nil
|
|
}
|
|
|
|
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
|
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var nsGroup *nbdns.NameServerGroup
|
|
result := tx.
|
|
Take(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
|
|
}
|
|
|
|
return nsGroup, nil
|
|
}
|
|
|
|
// SaveNameServerGroup saves a name server group to the database.
|
|
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, nameServerGroup *nbdns.NameServerGroup) error {
|
|
result := s.db.Save(nameServerGroup)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to save name server group to store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteNameServerGroup deletes a name server group from the database.
|
|
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID string) error {
|
|
result := s.db.Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to delete name server group from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewNameServerGroupNotFoundError(nsGroupID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveDNSSettings saves the DNS settings to the store.
|
|
func (s *SqlStore) SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error {
|
|
result := s.db.Model(&types.Account{}).
|
|
Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings})
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save dns settings to store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewAccountNotFoundError(accountID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveAccountSettings stores the account settings in DB.
|
|
func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error {
|
|
result := s.db.Model(&types.Account{}).
|
|
Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings})
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save account settings to store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewAccountNotFoundError(accountID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var networks []*networkTypes.Network
|
|
result := tx.Find(&networks, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get networks from store")
|
|
}
|
|
|
|
return networks, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var network *networkTypes.Network
|
|
result := tx.Take(&network, accountAndIDQueryCondition, accountID, networkID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewNetworkNotFoundError(networkID)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get network from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network from store")
|
|
}
|
|
|
|
return network, nil
|
|
}
|
|
|
|
func (s *SqlStore) SaveNetwork(ctx context.Context, network *networkTypes.Network) error {
|
|
result := s.db.Save(network)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save network to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteNetwork(ctx context.Context, accountID, networkID string) error {
|
|
result := s.db.Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete network from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewNetworkNotFoundError(networkID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netRouters []*routerTypes.NetworkRouter
|
|
result := tx.
|
|
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network routers from store")
|
|
}
|
|
|
|
return netRouters, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netRouters []*routerTypes.NetworkRouter
|
|
result := tx.
|
|
Find(&netRouters, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network routers from store")
|
|
}
|
|
|
|
return netRouters, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netRouter *routerTypes.NetworkRouter
|
|
result := tx.
|
|
Take(&netRouter, accountAndIDQueryCondition, accountID, routerID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewNetworkRouterNotFoundError(routerID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get network router from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network router from store")
|
|
}
|
|
|
|
return netRouter, nil
|
|
}
|
|
|
|
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
|
|
result := s.db.Save(router)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save network router to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error {
|
|
result := s.db.Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete network router from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewNetworkRouterNotFoundError(routerID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netResources []*resourceTypes.NetworkResource
|
|
result := tx.
|
|
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network resources from store")
|
|
}
|
|
|
|
return netResources, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netResources []*resourceTypes.NetworkResource
|
|
result := tx.
|
|
Find(&netResources, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network resources from store")
|
|
}
|
|
|
|
return netResources, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netResources *resourceTypes.NetworkResource
|
|
result := tx.
|
|
Take(&netResources, accountAndIDQueryCondition, accountID, resourceID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewNetworkResourceNotFoundError(resourceID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network resource from store")
|
|
}
|
|
|
|
return netResources, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var netResources *resourceTypes.NetworkResource
|
|
result := tx.
|
|
Take(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewNetworkResourceNotFoundError(resourceName)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get network resource from store")
|
|
}
|
|
|
|
return netResources, nil
|
|
}
|
|
|
|
func (s *SqlStore) SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error {
|
|
result := s.db.Save(resource)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save network resource to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error {
|
|
result := s.db.Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete network resource from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewNetworkResourceNotFoundError(resourceID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
|
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var pat types.PersonalAccessToken
|
|
result := tx.Take(&pat, "hashed_token = ?", hashedToken)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPATNotFoundError(hashedToken)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
|
|
}
|
|
|
|
return &pat, nil
|
|
}
|
|
|
|
// GetPATByID retrieves a personal access token by its ID and user ID.
|
|
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var pat types.PersonalAccessToken
|
|
result := tx.
|
|
Take(&pat, "id = ? AND user_id = ?", patID, userID)
|
|
if err := result.Error; err != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewPATNotFoundError(patID)
|
|
}
|
|
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get pat from store")
|
|
}
|
|
|
|
return &pat, nil
|
|
}
|
|
|
|
// GetUserPATs retrieves personal access tokens for a user.
|
|
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var pats []*types.PersonalAccessToken
|
|
result := tx.Find(&pats, "user_id = ?", userID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
|
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
|
|
}
|
|
|
|
return pats, nil
|
|
}
|
|
|
|
// MarkPATUsed marks a personal access token as used.
|
|
func (s *SqlStore) MarkPATUsed(ctx context.Context, patID string) error {
|
|
patCopy := types.PersonalAccessToken{
|
|
LastUsed: util.ToPtr(time.Now().UTC()),
|
|
}
|
|
|
|
fieldsToUpdate := []string{"last_used"}
|
|
result := s.db.Select(fieldsToUpdate).
|
|
Where(idQueryCondition, patID).Updates(&patCopy)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to mark pat as used")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewPATNotFoundError(patID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SavePAT saves a personal access token to the database.
|
|
func (s *SqlStore) SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error {
|
|
result := s.db.Save(pat)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to save pat to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeletePAT deletes a personal access token from the database.
|
|
func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
|
|
result := s.db.Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
|
if err := result.Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
|
|
return status.Errorf(status.Internal, "failed to delete pat from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewPATNotFoundError(patID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
|
|
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var token types.ProxyAccessToken
|
|
result := tx.Take(&token, "hashed_token = ?", hashedToken)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
|
}
|
|
return nil, status.Errorf(status.Internal, "get proxy access token: %v", result.Error)
|
|
}
|
|
|
|
return &token, nil
|
|
}
|
|
|
|
// GetAllProxyAccessTokens retrieves all proxy access tokens.
|
|
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var tokens []*types.ProxyAccessToken
|
|
result := tx.Find(&tokens)
|
|
if result.Error != nil {
|
|
return nil, status.Errorf(status.Internal, "get proxy access tokens: %v", result.Error)
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
// SaveProxyAccessToken saves a proxy access token to the database.
|
|
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
|
|
if result := s.db.Create(token); result.Error != nil {
|
|
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RevokeProxyAccessToken revokes a proxy access token by its ID.
|
|
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
|
result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "proxy access token not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
|
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
|
result := s.db.Model(&types.ProxyAccessToken{}).
|
|
Where(idQueryCondition, tokenID).
|
|
Update("last_used", time.Now().UTC())
|
|
if result.Error != nil {
|
|
return status.Errorf(status.Internal, "mark proxy access token as used: %v", result.Error)
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "proxy access token not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
jsonValue := fmt.Sprintf(`"%s"`, ip.String())
|
|
|
|
var peer nbpeer.Peer
|
|
result := tx.
|
|
Take(&peer, "account_id = ? AND ip = ?", accountID, jsonValue)
|
|
if result.Error != nil {
|
|
// no logging here
|
|
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
|
}
|
|
|
|
return &peer, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peerID string
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Select("id").
|
|
// Where(" = ?", hostname).
|
|
Where("account_id = ? AND dns_label = ?", accountID, hostname).
|
|
Limit(1).
|
|
Scan(&peerID)
|
|
|
|
if peerID == "" {
|
|
return "", gorm.ErrRecordNotFound
|
|
}
|
|
|
|
return peerID, result.Error
|
|
}
|
|
|
|
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
|
var count int64
|
|
result := s.db.Model(&types.Account{}).
|
|
Where("domain = ? AND domain_category = ?",
|
|
strings.ToLower(domain), types.PrivateCategory,
|
|
).Count(&count)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to count accounts by private domain %s: %s", domain, result.Error)
|
|
return 0, status.Errorf(status.Internal, "failed to count accounts by private domain")
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peers []types.GroupPeer
|
|
result := tx.Find(&peers, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get account group peers from store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get account group peers from store")
|
|
}
|
|
|
|
groupPeers := make(map[string]map[string]struct{})
|
|
for _, peer := range peers {
|
|
if _, exists := groupPeers[peer.GroupID]; !exists {
|
|
groupPeers[peer.GroupID] = make(map[string]struct{})
|
|
}
|
|
groupPeers[peer.GroupID][peer.PeerID] = struct{}{}
|
|
}
|
|
|
|
return groupPeers, nil
|
|
}
|
|
|
|
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
|
var info types.PrimaryAccountInfo
|
|
result := s.db.Model(&types.Account{}).
|
|
Select("is_domain_primary_account, domain").
|
|
Where(idQueryCondition, accountID).
|
|
Take(&info)
|
|
|
|
if result.Error != nil {
|
|
return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error)
|
|
}
|
|
|
|
return info.IsDomainPrimaryAccount, info.Domain, nil
|
|
}
|
|
|
|
func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
|
|
result := s.db.Model(&types.Account{}).
|
|
Where(idQueryCondition, accountID).
|
|
Update("is_domain_primary_account", true)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error)
|
|
return status.Errorf(status.Internal, "failed to mark account as primary")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewAccountNotFoundError(accountID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type accountNetworkPatch struct {
|
|
Network *types.Network `gorm:"embedded;embeddedPrefix:network_"`
|
|
}
|
|
|
|
func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error {
|
|
patch := accountNetworkPatch{
|
|
Network: &types.Network{Net: ipNet},
|
|
}
|
|
|
|
result := s.db.
|
|
Model(&types.Account{}).
|
|
Where(idQueryCondition, accountID).
|
|
Updates(&patch)
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to update account network")
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return status.NewAccountNotFoundError(accountID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
|
|
if len(groupIDs) == 0 {
|
|
return []*nbpeer.Peer{}, nil
|
|
}
|
|
|
|
var peers []*nbpeer.Peer
|
|
peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
|
|
Select("DISTINCT peer_id").
|
|
Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
|
|
|
|
result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var userID string
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Select("user_id").
|
|
Take(&userID, GetKeyQueryCondition(s), peerKey)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", status.Errorf(status.NotFound, "peer not found: index lookup failed")
|
|
}
|
|
return "", status.Errorf(status.Internal, "failed to get user ID by peer key")
|
|
}
|
|
|
|
return userID, nil
|
|
}
|
|
|
|
// GetPeerAuthInfoByPubKey returns the user_id and account_id for a peer in a
|
|
// single SELECT. Used by the Sync hot path to replace the back-to-back
|
|
// GetUserIDByPeerKey + GetAccountIDByPeerPubKey calls.
|
|
func (s *SqlStore) GetPeerAuthInfoByPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var row struct {
|
|
UserID string
|
|
AccountID string
|
|
}
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Select("user_id", "account_id").
|
|
Take(&row, GetKeyQueryCondition(s), peerKey)
|
|
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return "", "", status.Errorf(status.NotFound, "peer not found: index lookup failed")
|
|
}
|
|
return "", "", status.Errorf(status.Internal, "failed to get peer auth info by peer key")
|
|
}
|
|
|
|
return row.UserID, row.AccountID, nil
|
|
}
|
|
|
|
func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error {
|
|
result := s.db.Create(zone)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to create zone to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
|
|
result := s.db.Select("*").Save(zone)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to update zone to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error {
|
|
result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete zone from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewZoneNotFoundError(zoneID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var zone *zones.Zone
|
|
result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewZoneNotFoundError(zoneID)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get zone from store")
|
|
}
|
|
|
|
return zone, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) {
|
|
var zone *zones.Zone
|
|
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewZoneNotFoundError(domain)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get zone by domain from store")
|
|
}
|
|
|
|
return zone, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var zones []*zones.Zone
|
|
result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get zones from store")
|
|
}
|
|
|
|
return zones, nil
|
|
}
|
|
|
|
func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error {
|
|
result := s.db.Create(record)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to create dns record to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error {
|
|
result := s.db.Select("*").Save(record)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to update dns record to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error {
|
|
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete dns record from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.NewDNSRecordNotFoundError(recordID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var record *records.Record
|
|
result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.NewDNSRecordNotFoundError(recordID)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get dns record from store")
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var recordsList []*records.Record
|
|
result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get zone dns records from store")
|
|
}
|
|
|
|
return recordsList, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var recordsList []*records.Record
|
|
result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store")
|
|
}
|
|
|
|
return recordsList, nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error {
|
|
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete zone dns records from store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var peerID string
|
|
result := tx.Model(&nbpeer.Peer{}).
|
|
Select("id").
|
|
Where(GetKeyQueryCondition(s), key).
|
|
Limit(1).
|
|
Scan(&peerID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get peer ID by key: %s", result.Error)
|
|
return "", status.Errorf(status.Internal, "failed to get peer ID by key")
|
|
}
|
|
|
|
return peerID, nil
|
|
}
|
|
|
|
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
|
|
serviceCopy := service.Copy()
|
|
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return fmt.Errorf("encrypt service data: %w", err)
|
|
}
|
|
result := s.db.Create(serviceCopy)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to create service to store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to create service to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
|
|
serviceCopy := service.Copy()
|
|
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return fmt.Errorf("encrypt service data: %w", err)
|
|
}
|
|
|
|
// Create target type instance outside transaction to avoid variable shadowing
|
|
targetType := &rpservice.Target{}
|
|
|
|
// Use a transaction to ensure atomic updates of the service and its targets
|
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
|
// Delete existing targets
|
|
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update the service and create new targets
|
|
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(serviceCopy).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("failed to update service to store: %v", err)
|
|
return status.Errorf(status.Internal, "failed to update service to store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
|
|
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete service from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "service %s not found", serviceID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
|
|
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete target from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
|
|
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete targets from store")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTargetsByServiceID retrieves all targets for a given service
|
|
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
|
|
var targets []*rpservice.Target
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get targets from store")
|
|
}
|
|
|
|
return targets, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
|
|
tx := s.db.Preload("Targets")
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var service *rpservice.Service
|
|
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "service %s not found", serviceID)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get service from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get service from store")
|
|
}
|
|
|
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
|
}
|
|
|
|
return service, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
|
var service *rpservice.Service
|
|
result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get service by domain from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get service by domain from store")
|
|
}
|
|
|
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
|
}
|
|
|
|
return service, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
|
|
tx := s.db.Preload("Targets")
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var serviceList []*rpservice.Service
|
|
result := tx.Find(&serviceList)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
|
}
|
|
|
|
for _, service := range serviceList {
|
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
|
}
|
|
}
|
|
|
|
return serviceList, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
|
tx := s.db.Preload("Targets")
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var serviceList []*rpservice.Service
|
|
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
|
}
|
|
|
|
for _, service := range serviceList {
|
|
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
|
return nil, fmt.Errorf("decrypt service data: %w", err)
|
|
}
|
|
}
|
|
|
|
return serviceList, nil
|
|
}
|
|
|
|
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
|
|
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error {
|
|
result := s.db.Model(&rpservice.Service{}).
|
|
Where("id = ? AND account_id = ? AND source_peer = ? AND source = ?", serviceID, accountID, peerID, rpservice.SourceEphemeral).
|
|
Update("meta_last_renewed_at", time.Now())
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
|
|
return status.Errorf(status.Internal, "renew ephemeral service")
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "no active expose session for service %s", serviceID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL.
|
|
// Only the fields needed for reaping are selected. The limit parameter caps the batch size to
|
|
// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to
|
|
// skip malformed legacy data.
|
|
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
|
|
cutoff := time.Now().Add(-ttl)
|
|
var services []*rpservice.Service
|
|
result := s.db.
|
|
Select("id", "account_id", "source_peer", "domain").
|
|
Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff).
|
|
Limit(limit).
|
|
Find(&services)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "get expired ephemeral services")
|
|
}
|
|
return services, nil
|
|
}
|
|
|
|
// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer.
|
|
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
|
|
// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to
|
|
// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*).
|
|
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
|
|
if lockStrength == LockingStrengthNone {
|
|
var count int64
|
|
result := s.db.Model(&rpservice.Service{}).
|
|
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
|
|
Count(&count)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
|
|
return 0, status.Errorf(status.Internal, "count ephemeral services")
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
var ids []string
|
|
result := s.db.Model(&rpservice.Service{}).
|
|
Clauses(clause.Locking{Strength: string(lockStrength)}).
|
|
Select("id").
|
|
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
|
|
Pluck("id", &ids)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
|
|
return 0, status.Errorf(status.Internal, "count ephemeral services")
|
|
}
|
|
return int64(len(ids)), nil
|
|
}
|
|
|
|
// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain.
|
|
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
|
|
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
|
if lockStrength == LockingStrengthNone {
|
|
var count int64
|
|
result := s.db.Model(&rpservice.Service{}).
|
|
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
|
Count(&count)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
|
|
return false, status.Errorf(status.Internal, "check ephemeral service existence")
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
var id string
|
|
result := s.db.Model(&rpservice.Service{}).
|
|
Clauses(clause.Locking{Strength: string(lockStrength)}).
|
|
Select("id").
|
|
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
|
Limit(1).
|
|
Pluck("id", &id)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
|
|
return false, status.Errorf(status.Internal, "check ephemeral service existence")
|
|
}
|
|
return id != "", nil
|
|
}
|
|
|
|
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
|
|
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var services []*rpservice.Service
|
|
result := tx.Where("proxy_cluster = ? AND mode = ? AND listen_port = ?", proxyCluster, mode, listenPort).Find(&services)
|
|
if result.Error != nil {
|
|
return nil, status.Errorf(status.Internal, "query services by cluster and port")
|
|
}
|
|
|
|
return services, nil
|
|
}
|
|
|
|
// GetServicesByCluster returns all services for the given proxy cluster.
|
|
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var services []*rpservice.Service
|
|
result := tx.Where("proxy_cluster = ?", proxyCluster).Find(&services)
|
|
if result.Error != nil {
|
|
return nil, status.Errorf(status.Internal, "query services by cluster")
|
|
}
|
|
return services, nil
|
|
}
|
|
|
|
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
|
|
tx := s.db
|
|
|
|
customDomain := &domain.Domain{}
|
|
result := tx.Take(&customDomain, accountAndIDQueryCondition, accountID, domainID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "custom domain %s not found", domainID)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get custom domain from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get custom domain from store")
|
|
}
|
|
|
|
return customDomain, nil
|
|
}
|
|
|
|
func (s *SqlStore) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *SqlStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
|
|
tx := s.db
|
|
|
|
var domains []*domain.Domain
|
|
result := tx.Find(&domains, accountIDCondition, accountID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get reverse proxy custom domains from the store: %s", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get reverse proxy custom domains from store")
|
|
}
|
|
|
|
return domains, nil
|
|
}
|
|
|
|
func (s *SqlStore) CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error) {
|
|
newDomain := &domain.Domain{
|
|
ID: xid.New().String(), // Generate our own ID because gorm doesn't always configure the database to handle this for us.
|
|
Domain: domainName,
|
|
AccountID: accountID,
|
|
TargetCluster: targetCluster,
|
|
Type: domain.TypeCustom,
|
|
Validated: validated,
|
|
}
|
|
result := s.db.Create(newDomain)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to create reverse proxy custom domain to store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to create reverse proxy custom domain to store")
|
|
}
|
|
|
|
return newDomain, nil
|
|
}
|
|
|
|
func (s *SqlStore) UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error) {
|
|
d.AccountID = accountID
|
|
result := s.db.Select("*").Save(d)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to update reverse proxy custom domain to store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to update reverse proxy custom domain to store")
|
|
}
|
|
|
|
return d, nil
|
|
}
|
|
|
|
func (s *SqlStore) DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error {
|
|
result := s.db.Delete(domain.Domain{}, accountAndIDQueryCondition, accountID, domainID)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete reverse proxy custom domain from store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to delete reverse proxy custom domain from store")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return status.Errorf(status.NotFound, "reverse proxy custom domain %s not found", domainID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CreateAccessLog creates a new access log entry in the database
|
|
func (s *SqlStore) CreateAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
|
|
result := s.db.Create(logEntry)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).WithFields(log.Fields{
|
|
"service_id": logEntry.ServiceID,
|
|
"method": logEntry.Method,
|
|
"host": logEntry.Host,
|
|
"path": logEntry.Path,
|
|
}).Errorf("failed to create access log entry in store: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to create access log entry in store")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAccountAccessLogs retrieves access logs for a given account with pagination and filtering
|
|
func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
|
var logs []*accesslogs.AccessLogEntry
|
|
var totalCount int64
|
|
|
|
baseQuery := s.db.
|
|
Model(&accesslogs.AccessLogEntry{}).
|
|
Where(accountIDCondition, accountID)
|
|
|
|
baseQuery = s.applyAccessLogFilters(baseQuery, filter)
|
|
|
|
if err := baseQuery.Count(&totalCount).Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to count access logs: %v", err)
|
|
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
|
|
}
|
|
|
|
query := s.db.
|
|
Where(accountIDCondition, accountID)
|
|
|
|
query = s.applyAccessLogFilters(query, filter)
|
|
|
|
sortColumns := filter.GetSortColumn()
|
|
sortOrder := strings.ToUpper(filter.GetSortOrder())
|
|
|
|
var orderClauses []string
|
|
for _, col := range strings.Split(sortColumns, ",") {
|
|
col = strings.TrimSpace(col)
|
|
if col != "" {
|
|
orderClauses = append(orderClauses, col+" "+sortOrder)
|
|
}
|
|
}
|
|
orderClause := strings.Join(orderClauses, ", ")
|
|
|
|
query = query.
|
|
Order(orderClause).
|
|
Limit(filter.GetLimit()).
|
|
Offset(filter.GetOffset())
|
|
|
|
if lockStrength != LockingStrengthNone {
|
|
query = query.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
result := query.Find(&logs)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get access logs from store: %v", result.Error)
|
|
return nil, 0, status.Errorf(status.Internal, "failed to get access logs from store")
|
|
}
|
|
|
|
return logs, totalCount, nil
|
|
}
|
|
|
|
// DeleteOldAccessLogs deletes all access logs older than the specified time
|
|
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
|
|
result := s.db.
|
|
Where("timestamp < ?", olderThan).
|
|
Delete(&accesslogs.AccessLogEntry{})
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to delete old access logs: %v", result.Error)
|
|
return 0, status.Errorf(status.Internal, "failed to delete old access logs")
|
|
}
|
|
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
// applyAccessLogFilters applies filter conditions to the query
|
|
func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.AccessLogFilter) *gorm.DB {
|
|
if filter.Search != nil {
|
|
searchPattern := "%" + *filter.Search + "%"
|
|
query = query.Where(
|
|
"id LIKE ? OR location_connection_ip LIKE ? OR host LIKE ? OR path LIKE ? OR CONCAT(host, path) LIKE ? OR user_id IN (SELECT id FROM users WHERE email LIKE ? OR name LIKE ?)",
|
|
searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern,
|
|
)
|
|
}
|
|
|
|
if filter.SourceIP != nil {
|
|
query = query.Where("location_connection_ip = ?", *filter.SourceIP)
|
|
}
|
|
|
|
if filter.Host != nil {
|
|
query = query.Where("host = ?", *filter.Host)
|
|
}
|
|
|
|
if filter.Path != nil {
|
|
// Support LIKE pattern for path filtering
|
|
query = query.Where("path LIKE ?", "%"+*filter.Path+"%")
|
|
}
|
|
|
|
if filter.UserID != nil {
|
|
query = query.Where("user_id = ?", *filter.UserID)
|
|
}
|
|
|
|
if filter.Method != nil {
|
|
query = query.Where("method = ?", *filter.Method)
|
|
}
|
|
|
|
if filter.Status != nil {
|
|
switch *filter.Status {
|
|
case "success":
|
|
query = query.Where("status_code >= ? AND status_code < ?", 200, 400)
|
|
case "failed":
|
|
query = query.Where("status_code < ? OR status_code >= ?", 200, 400)
|
|
}
|
|
}
|
|
|
|
if filter.StatusCode != nil {
|
|
query = query.Where("status_code = ?", *filter.StatusCode)
|
|
}
|
|
|
|
if filter.StartDate != nil {
|
|
query = query.Where("timestamp >= ?", *filter.StartDate)
|
|
}
|
|
|
|
if filter.EndDate != nil {
|
|
query = query.Where("timestamp <= ?", *filter.EndDate)
|
|
}
|
|
|
|
return query
|
|
}
|
|
|
|
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
|
|
tx := s.db
|
|
if lockStrength != LockingStrengthNone {
|
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
|
}
|
|
|
|
var target *rpservice.Target
|
|
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, status.Errorf(status.NotFound, "service target with ID %s not found", targetID)
|
|
}
|
|
|
|
log.WithContext(ctx).Errorf("failed to get service target from store: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get service target from store")
|
|
}
|
|
|
|
return target, nil
|
|
}
|
|
|
|
// SaveProxy saves or updates a proxy in the database
|
|
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
|
result := s.db.Save(p)
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to save proxy")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
|
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
|
now := time.Now()
|
|
|
|
result := s.db.
|
|
Model(&proxy.Proxy{}).
|
|
Where("id = ? AND status = ?", proxyID, "connected").
|
|
Update("last_seen", now)
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
p := &proxy.Proxy{
|
|
ID: proxyID,
|
|
ClusterAddress: clusterAddress,
|
|
IPAddress: ipAddress,
|
|
LastSeen: now,
|
|
ConnectedAt: &now,
|
|
Status: "connected",
|
|
}
|
|
if err := s.db.Save(p).Error; err != nil {
|
|
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
|
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
|
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
|
var addresses []string
|
|
|
|
result := s.db.
|
|
Model(&proxy.Proxy{}).
|
|
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
|
Distinct("cluster_address").
|
|
Pluck("cluster_address", &addresses)
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
|
|
}
|
|
|
|
return addresses, nil
|
|
}
|
|
|
|
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
|
|
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
|
var clusters []proxy.Cluster
|
|
|
|
result := s.db.Model(&proxy.Proxy{}).
|
|
Select("cluster_address as address, COUNT(*) as connected_proxies").
|
|
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
|
Group("cluster_address").
|
|
Scan(&clusters)
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error)
|
|
return nil, status.Errorf(status.Internal, "get active proxy clusters")
|
|
}
|
|
|
|
return clusters, nil
|
|
}
|
|
|
|
// proxyActiveThreshold bounds how old a proxy's last heartbeat may be while
// the proxy is still treated as active. It must be no less than twice the
// heartbeat interval (1 min).
const proxyActiveThreshold time.Duration = 2 * time.Minute
|
|
|
|
// validCapabilityColumns is the allow-list of proxy capability column names.
// The capability helpers check a column name against this set before splicing
// it into a raw SQL SELECT expression.
var validCapabilityColumns = map[string]struct{}{
	"require_subdomain":     {},
	"supports_crowdsec":     {},
	"supports_custom_ports": {},
}
|
|
|
|
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
|
// supports custom ports. Returns nil when no proxy reported the capability.
|
|
func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
|
return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports")
|
|
}
|
|
|
|
// GetClusterRequireSubdomain returns whether any active proxy in the cluster
|
|
// requires a subdomain. Returns nil when no proxy reported the capability.
|
|
func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
|
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
|
|
}
|
|
|
|
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
|
|
// have CrowdSec configured. Returns nil when no proxy reported the capability.
|
|
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
|
|
// requires unanimous support: a single unconfigured proxy would let requests
|
|
// bypass reputation checks.
|
|
func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
|
return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec")
|
|
}
|
|
|
|
// getClusterUnanimousCapability returns an aggregated boolean capability
|
|
// requiring all active proxies in the cluster to report true.
|
|
func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool {
|
|
if _, ok := validCapabilityColumns[column]; !ok {
|
|
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
|
|
return nil
|
|
}
|
|
|
|
var result struct {
|
|
Total int64
|
|
Reported int64
|
|
AllTrue bool
|
|
}
|
|
|
|
// All active proxies must have reported the capability (no NULLs) and all
|
|
// must report true. A single unreported or false proxy means the cluster
|
|
// does not unanimously support the capability.
|
|
err := s.db.WithContext(ctx).
|
|
Model(&proxy.Proxy{}).
|
|
Select("COUNT(*) AS total, "+
|
|
"COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+
|
|
"COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true").
|
|
Where("cluster_address = ? AND status = ? AND last_seen > ?",
|
|
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
|
|
Scan(&result).Error
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
|
|
return nil
|
|
}
|
|
|
|
if result.Total == 0 || result.Reported == 0 {
|
|
return nil
|
|
}
|
|
|
|
// If any proxy has not reported (NULL), we can't confirm unanimous support.
|
|
if result.Reported < result.Total {
|
|
v := false
|
|
return &v
|
|
}
|
|
|
|
return &result.AllTrue
|
|
}
|
|
|
|
// getClusterCapability returns an aggregated boolean capability for the given
|
|
// cluster. It checks active (connected, recently seen) proxies and returns:
|
|
// - *true if any proxy in the cluster has the capability set to true,
|
|
// - *false if at least one proxy reported but none set it to true,
|
|
// - nil if no proxy reported the capability at all.
|
|
func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool {
|
|
if _, ok := validCapabilityColumns[column]; !ok {
|
|
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
|
|
return nil
|
|
}
|
|
|
|
var result struct {
|
|
HasCapability bool
|
|
AnyTrue bool
|
|
}
|
|
|
|
err := s.db.
|
|
Model(&proxy.Proxy{}).
|
|
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
|
|
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
|
|
Where("cluster_address = ? AND status = ? AND last_seen > ?",
|
|
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
|
|
Scan(&result).Error
|
|
|
|
if err != nil {
|
|
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
|
|
return nil
|
|
}
|
|
|
|
if !result.HasCapability {
|
|
return nil
|
|
}
|
|
|
|
return &result.AnyTrue
|
|
}
|
|
|
|
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
|
|
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
|
cutoffTime := time.Now().Add(-inactivityDuration)
|
|
|
|
result := s.db.
|
|
Where("last_seen < ?", cutoffTime).
|
|
Delete(&proxy.Proxy{})
|
|
|
|
if result.Error != nil {
|
|
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
|
|
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
|
|
}
|
|
|
|
if result.RowsAffected > 0 {
|
|
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetRoutingPeerNetworks returns the distinct network names where the peer is assigned as a routing peer
|
|
// in an enabled network router, either directly or via peer groups.
|
|
func (s *SqlStore) GetRoutingPeerNetworks(_ context.Context, accountID, peerID string) ([]string, error) {
|
|
var routers []*routerTypes.NetworkRouter
|
|
if err := s.db.Select("peer, peer_groups, network_id").Where("account_id = ? AND enabled = true", accountID).Find(&routers).Error; err != nil {
|
|
return nil, status.Errorf(status.Internal, "failed to get enabled routers: %v", err)
|
|
}
|
|
|
|
if len(routers) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
var groupPeers []types.GroupPeer
|
|
if err := s.db.Select("group_id").Where("account_id = ? AND peer_id = ?", accountID, peerID).Find(&groupPeers).Error; err != nil {
|
|
return nil, status.Errorf(status.Internal, "failed to get peer group memberships: %v", err)
|
|
}
|
|
|
|
groupSet := make(map[string]struct{}, len(groupPeers))
|
|
for _, gp := range groupPeers {
|
|
groupSet[gp.GroupID] = struct{}{}
|
|
}
|
|
|
|
networkIDs := make(map[string]struct{})
|
|
for _, r := range routers {
|
|
if r.Peer == peerID {
|
|
networkIDs[r.NetworkID] = struct{}{}
|
|
} else if r.Peer == "" {
|
|
for _, pg := range r.PeerGroups {
|
|
if _, ok := groupSet[pg]; ok {
|
|
networkIDs[r.NetworkID] = struct{}{}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(networkIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
ids := make([]string, 0, len(networkIDs))
|
|
for id := range networkIDs {
|
|
ids = append(ids, id)
|
|
}
|
|
|
|
var networks []*networkTypes.Network
|
|
if err := s.db.Select("name").Where("account_id = ? AND id IN ?", accountID, ids).Find(&networks).Error; err != nil {
|
|
return nil, status.Errorf(status.Internal, "failed to get networks: %v", err)
|
|
}
|
|
|
|
names := make([]string, 0, len(networks))
|
|
for _, n := range networks {
|
|
names = append(names, n.Name)
|
|
}
|
|
|
|
return names, nil
|
|
}
|