mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Faster server bootstrap by counting accounts rather than fetching all from storage in the account manager instantiation. This change moved the deprecated need to ensure accounts have an All group to tests instead.
563 lines
25 KiB
Go
563 lines
25 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
log "github.com/sirupsen/logrus"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/management/server/testutil"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
|
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
"github.com/netbirdio/netbird/util"
|
|
|
|
"github.com/netbirdio/netbird/management/server/migration"
|
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
"github.com/netbirdio/netbird/management/server/posture"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
// LockingStrength names the lock strength a caller requests from a store
// method that accepts a lockStrength argument.
type LockingStrength string

const (
	// LockingStrengthUpdate is the strongest lock, preventing any changes by
	// other transactions until your transaction completes.
	LockingStrengthUpdate LockingStrength = "UPDATE"

	// LockingStrengthShare allows reading but prevents changes by other
	// transactions.
	LockingStrengthShare LockingStrength = "SHARE"

	// LockingStrengthNoKeyUpdate is similar to UPDATE but allows changes to
	// related rows.
	LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE"

	// LockingStrengthKeyShare protects against changes to primary/unique keys
	// but allows other updates.
	LockingStrengthKeyShare LockingStrength = "KEY SHARE"
)
|
|
|
|
type Store interface {
|
|
GetAccountsCounter(ctx context.Context) (int64, error)
|
|
GetAllAccounts(ctx context.Context) []*types.Account
|
|
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
|
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
|
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
|
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
|
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
|
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
|
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
|
GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later
|
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error)
|
|
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
|
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
|
SaveAccount(ctx context.Context, account *types.Account) error
|
|
DeleteAccount(ctx context.Context, account *types.Account) error
|
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
|
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
|
|
|
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
|
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
|
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
|
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
|
|
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
|
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
|
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
|
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
|
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
|
|
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
|
|
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
|
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
|
|
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
|
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error
|
|
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
|
|
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
|
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
|
|
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
|
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
|
|
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
|
|
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
|
|
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
|
|
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
|
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
|
|
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
|
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
|
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
|
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
|
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
|
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
|
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
|
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
|
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
|
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
|
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
|
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
|
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
|
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
|
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
|
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
|
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
|
|
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
|
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
|
|
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
|
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
|
|
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error
|
|
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
|
|
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
|
|
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
|
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
|
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
|
|
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
|
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
|
|
|
|
GetInstallationID() string
|
|
SaveInstallationID(ctx context.Context, ID string) error
|
|
|
|
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
|
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
|
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
|
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
|
AcquireGlobalLock(ctx context.Context) func()
|
|
|
|
// Close should close the store persisting all unsaved data.
|
|
Close(ctx context.Context) error
|
|
// GetStoreEngine should return Engine of the current store implementation.
|
|
// This is also a method of metrics.DataSource interface.
|
|
GetStoreEngine() Engine
|
|
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
|
|
|
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
|
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
|
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error
|
|
DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error
|
|
|
|
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
|
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
|
|
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
|
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error
|
|
DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error
|
|
|
|
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
|
GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error)
|
|
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
|
|
GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error)
|
|
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
|
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
|
}
|
|
|
|
// Engine identifies a store backend.
type Engine string

// Known store engines.
const (
	FileStoreEngine     Engine = "jsonfile"
	SqliteStoreEngine   Engine = "sqlite"
	PostgresStoreEngine Engine = "postgres"
	MysqlStoreEngine    Engine = "mysql"
)

// Names of the environment variables that carry the DSN for the external
// database engines.
const (
	postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
	mysqlDsnEnv    = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)

// supportedEngines lists the engines accepted from the environment;
// FileStoreEngine is intentionally absent from the list.
var supportedEngines = []Engine{
	SqliteStoreEngine,
	PostgresStoreEngine,
	MysqlStoreEngine,
}
|
|
|
|
func getStoreEngineFromEnv() Engine {
|
|
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
|
|
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
|
|
if !ok {
|
|
return ""
|
|
}
|
|
|
|
value := Engine(strings.ToLower(kind))
|
|
if slices.Contains(supportedEngines, value) {
|
|
return value
|
|
}
|
|
|
|
return SqliteStoreEngine
|
|
}
|
|
|
|
// getStoreEngine determines the store engine to use.
|
|
// If no engine is specified, it attempts to retrieve it from the environment.
|
|
// If still not specified, it defaults to using SQLite.
|
|
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
|
|
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
|
|
if kind == "" {
|
|
kind = getStoreEngineFromEnv()
|
|
if kind == "" {
|
|
kind = SqliteStoreEngine
|
|
|
|
// Migrate if it is the first run with a JSON file existing and no SQLite file present
|
|
jsonStoreFile := filepath.Join(dataDir, storeFileName)
|
|
sqliteStoreFile := filepath.Join(dataDir, storeSqliteFileName)
|
|
|
|
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
|
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
|
|
|
// Attempt to migrate from JSON store to SQLite
|
|
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
|
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
|
|
kind = FileStoreEngine
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return kind
|
|
}
|
|
|
|
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
|
|
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
|
kind = getStoreEngine(ctx, dataDir, kind)
|
|
|
|
if err := checkFileStoreEngine(kind, dataDir); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch kind {
|
|
case SqliteStoreEngine:
|
|
log.WithContext(ctx).Info("using SQLite store engine")
|
|
return NewSqliteStore(ctx, dataDir, metrics)
|
|
case PostgresStoreEngine:
|
|
log.WithContext(ctx).Info("using Postgres store engine")
|
|
return newPostgresStore(ctx, metrics)
|
|
case MysqlStoreEngine:
|
|
log.WithContext(ctx).Info("using MySQL store engine")
|
|
return newMysqlStore(ctx, metrics)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
|
|
}
|
|
}
|
|
|
|
func checkFileStoreEngine(kind Engine, dataDir string) error {
|
|
if kind == FileStoreEngine {
|
|
storeFile := filepath.Join(dataDir, storeFileName)
|
|
if util.FileExists(storeFile) {
|
|
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
|
|
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// migrate migrates the SQLite database to the latest schema
|
|
func migrate(ctx context.Context, db *gorm.DB) error {
|
|
migrations := getMigrations(ctx)
|
|
|
|
for _, m := range migrations {
|
|
if err := m(db); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getMigrations(ctx context.Context) []migrationFunc {
|
|
return []migrationFunc{
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateFieldFromGobToJSON[route.Route, []string](ctx, db, "peer_groups")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "location_connection_ip", "")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db)
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNewField[resourceTypes.NetworkResource](ctx, db, "enabled", true)
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true)
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
|
// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database
|
|
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
|
|
kind := getStoreEngineFromEnv()
|
|
if kind == "" {
|
|
kind = SqliteStoreEngine
|
|
}
|
|
|
|
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
|
if runtime.GOOS == "windows" {
|
|
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
|
storeStr = storeSqliteFileName
|
|
}
|
|
|
|
file := filepath.Join(dataDir, storeStr)
|
|
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if filename != "" {
|
|
err = loadSQL(db, filename)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to load SQL file: %v", err)
|
|
}
|
|
}
|
|
|
|
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
|
|
}
|
|
|
|
err = addAllGroupToAccount(ctx, store)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to add all group to account: %v", err)
|
|
}
|
|
|
|
return getSqlStoreEngine(ctx, store, kind)
|
|
}
|
|
|
|
func addAllGroupToAccount(ctx context.Context, store Store) error {
|
|
allAccounts := store.GetAllAccounts(ctx)
|
|
for _, account := range allAccounts {
|
|
shouldSave := false
|
|
|
|
_, err := account.GetGroupAll()
|
|
if err != nil {
|
|
if err := account.AddAllGroup(); err != nil {
|
|
return err
|
|
}
|
|
shouldSave = true
|
|
}
|
|
|
|
if shouldSave {
|
|
err = store.SaveAccount(ctx, account)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
|
|
var cleanup func()
|
|
var err error
|
|
switch kind {
|
|
case PostgresStoreEngine:
|
|
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
|
|
case MysqlStoreEngine:
|
|
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
|
|
default:
|
|
cleanup = func() {
|
|
// sqlite doesn't need to be cleaned up
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, cleanup, fmt.Errorf("failed to create test store: %v", err)
|
|
}
|
|
|
|
closeConnection := func() {
|
|
cleanup()
|
|
store.Close(ctx)
|
|
}
|
|
|
|
return store, closeConnection, nil
|
|
}
|
|
|
|
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
|
|
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
|
|
var err error
|
|
_, err = testutil.CreatePostgresTestContainer()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
|
}
|
|
|
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
|
|
}
|
|
|
|
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
|
if err != nil {
|
|
return nil, cleanup, err
|
|
}
|
|
|
|
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
|
|
if err != nil {
|
|
return nil, cleanup, err
|
|
}
|
|
|
|
return store, cleanup, nil
|
|
}
|
|
|
|
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
|
|
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
|
|
var err error
|
|
_, err = testutil.CreateMysqlTestContainer()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
dsn, ok := os.LookupEnv(mysqlDsnEnv)
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
|
}
|
|
|
|
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
|
|
}
|
|
|
|
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
|
if err != nil {
|
|
return nil, cleanup, err
|
|
}
|
|
|
|
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return store, cleanup, nil
|
|
}
|
|
|
|
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
|
|
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
|
|
|
|
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
|
|
return "", nil, fmt.Errorf("failed to create database: %v", err)
|
|
}
|
|
|
|
var err error
|
|
cleanup := func() {
|
|
switch engine {
|
|
case PostgresStoreEngine:
|
|
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
|
|
case MysqlStoreEngine:
|
|
// err = killMySQLConnections(dsn, dbName)
|
|
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
|
|
}
|
|
if err != nil {
|
|
log.Errorf("failed to drop database %s: %v", dbName, err)
|
|
panic(err)
|
|
}
|
|
sqlDB, _ := db.DB()
|
|
_ = sqlDB.Close()
|
|
}
|
|
|
|
return replaceDBName(dsn, dbName), cleanup, nil
|
|
}
|
|
|
|
// dbNameInDSNRe locates the database name in a URL-style or MySQL-style DSN:
// a run of characters without '/' or '?' that follows ':', '/' or '@' and is
// terminated by '?' or the end of the string.
var dbNameInDSNRe = regexp.MustCompile(`(?P<pre>[:/@])(?P<dbname>[^/?]+)(?P<post>\?|$)`)

// replaceDBName returns dsn with its database name replaced by newDBName.
//
// Only the first (leftmost) match is rewritten, so query parameters whose
// values contain '/', ':' or '@' (for example "loc=Asia/Shanghai") are left
// intact. newDBName is inserted literally. A dsn in which no database name is
// found (for example a key=value DSN) is returned unchanged.
func replaceDBName(dsn, newDBName string) string {
	loc := dbNameInDSNRe.FindStringSubmatchIndex(dsn)
	if loc == nil {
		return dsn
	}

	// loc[4]:loc[5] is the span of the "dbname" group; everything around it,
	// including the "pre" and "post" delimiters, is kept as is.
	return dsn[:loc[4]] + newDBName + dsn[loc[5]:]
}
|
|
|
|
func loadSQL(db *gorm.DB, filepath string) error {
|
|
sqlContent, err := os.ReadFile(filepath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
queries := strings.Split(string(sqlContent), ";")
|
|
|
|
for _, query := range queries {
|
|
query = strings.TrimSpace(query)
|
|
if query != "" {
|
|
err := db.Exec(query).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MigrateFileStoreToSqlite migrates the file store to the SQLite store.
|
|
func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
|
|
fileStorePath := path.Join(dataDir, storeFileName)
|
|
if _, err := os.Stat(fileStorePath); errors.Is(err, os.ErrNotExist) {
|
|
return fmt.Errorf("%s doesn't exist, couldn't continue the operation", fileStorePath)
|
|
}
|
|
|
|
sqlStorePath := path.Join(dataDir, storeSqliteFileName)
|
|
if _, err := os.Stat(sqlStorePath); err == nil {
|
|
return fmt.Errorf("%s already exists, couldn't continue the operation", sqlStorePath)
|
|
}
|
|
|
|
fstore, err := NewFileStore(ctx, dataDir, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed creating file store: %s: %v", dataDir, err)
|
|
}
|
|
|
|
fsStoreAccounts := len(fstore.GetAllAccounts(ctx))
|
|
log.WithContext(ctx).Infof("%d account will be migrated from file store %s to sqlite store %s",
|
|
fsStoreAccounts, fileStorePath, sqlStorePath)
|
|
|
|
store, err := NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed creating file store: %s: %v", dataDir, err)
|
|
}
|
|
|
|
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
|
if fsStoreAccounts != sqliteStoreAccounts {
|
|
return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d",
|
|
fsStoreAccounts, sqliteStoreAccounts)
|
|
}
|
|
|
|
return nil
|
|
}
|