mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-19 23:29:56 +00:00
* [management] Fence peer status updates with a session token The connect/disconnect path used a best-effort LastSeen-after-streamStart comparison to decide whether a status update should land. Under contention — a re-sync arriving while the previous stream's disconnect was still in flight, or two management replicas seeing the same peer at once — the check was a read-then-decide-then-write window: any UPDATE in between caused the wrong row to be written. The Go-side time.Now() that fed the comparison also drifted under lock contention, since it was captured seconds before the write actually committed. Replace it with an integer-nanosecond fencing token stored alongside the status. Every gRPC sync stream uses its open time (UnixNano) as its token. Connects only land when the incoming token is strictly greater than the stored one; disconnects only land when the incoming token equals the stored one (i.e. we're the stream that owns the current session). Both are single optimistic-locked UPDATEs — no read-then-write, no transaction wrapper. LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The caller never supplies it, so the value always reflects the real moment of the UPDATE rather than the moment the caller queued the work — which was already off by minutes under heavy lock contention. Side effects (geo lookup, peer-login-expiration scheduling, network-map fan-out) are explicitly documented as running after the fence UPDATE commits, never inside it. Geo also skips the update when realIP equals the stored ConnectionIP, dropping a redundant SavePeerLocation call on same-IP reconnects. Tests cover the three semantic cases (matched disconnect lands, stale disconnect dropped, stale connect dropped) plus a 16-goroutine race test that asserts the highest token always wins. 
* [management] Add SessionStartedAt to peer status updates Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety.
* Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields
866 lines
39 KiB
Go
866 lines
39 KiB
Go
package store
|
|
|
|
//go:generate go run github.com/golang/mock/mockgen -package store -destination=store_mock.go -source=./store.go -build_flags=-mod=mod
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
log "github.com/sirupsen/logrus"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
|
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
"github.com/netbirdio/netbird/management/server/testutil"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
"github.com/netbirdio/netbird/util"
|
|
"github.com/netbirdio/netbird/util/crypt"
|
|
|
|
"github.com/netbirdio/netbird/management/server/migration"
|
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
"github.com/netbirdio/netbird/management/server/posture"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
// LockingStrength is the string name of a lock mode that a caller passes to
// the store's read methods via their lockStrength parameter.
type LockingStrength string

// Lock modes accepted wherever a LockingStrength is expected.
const (
	LockingStrengthUpdate      LockingStrength = "UPDATE"        // Strongest lock, preventing any changes by other transactions until your transaction completes.
	LockingStrengthShare       LockingStrength = "SHARE"         // Allows reading but prevents changes by other transactions.
	LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
	LockingStrengthKeyShare    LockingStrength = "KEY SHARE"     // Protects against changes to primary/unique keys but allows other updates.
	LockingStrengthNone        LockingStrength = "NONE"          // No locking, allowing all transactions to proceed without restrictions.
)
|
|
|
|
type Store interface {
|
|
GetAccountsCounter(ctx context.Context) (int64, error)
|
|
GetAllAccounts(ctx context.Context) []*types.Account
|
|
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
|
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
|
|
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
|
|
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
|
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
|
|
GetAnyAccountID(ctx context.Context) (string, error)
|
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
|
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
|
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
|
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
|
GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later
|
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error)
|
|
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
|
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
|
SaveAccount(ctx context.Context, account *types.Account) error
|
|
DeleteAccount(ctx context.Context, account *types.Account) error
|
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
|
SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error
|
|
SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error
|
|
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
|
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
|
|
|
|
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
|
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
|
|
SaveUsers(ctx context.Context, users []*types.User) error
|
|
SaveUser(ctx context.Context, user *types.User) error
|
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
|
DeleteUser(ctx context.Context, accountID, userID string) error
|
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
|
|
|
SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error
|
|
GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error)
|
|
GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error)
|
|
GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error)
|
|
GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error)
|
|
DeleteUserInvite(ctx context.Context, inviteID string) error
|
|
|
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
|
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
|
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
|
MarkPATUsed(ctx context.Context, patID string) error
|
|
SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
|
|
DeletePAT(ctx context.Context, userID, patID string) error
|
|
|
|
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
|
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
|
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
|
|
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
|
|
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
|
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
|
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
|
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
|
|
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
|
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
|
|
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
|
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
|
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
|
CreateGroup(ctx context.Context, group *types.Group) error
|
|
UpdateGroup(ctx context.Context, group *types.Group) error
|
|
DeleteGroup(ctx context.Context, accountID, groupID string) error
|
|
DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error
|
|
|
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
|
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
|
|
CreatePolicy(ctx context.Context, policy *types.Policy) error
|
|
SavePolicy(ctx context.Context, policy *types.Policy) error
|
|
DeletePolicy(ctx context.Context, accountID, policyID string) error
|
|
|
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
|
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
|
SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error
|
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error
|
|
|
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
|
AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error
|
|
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
|
|
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
|
|
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
|
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
|
|
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
|
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
|
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
|
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
|
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
|
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
|
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
|
|
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
|
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
|
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
|
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
|
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
|
|
// given session token, but only when the stored SessionStartedAt is
|
|
// strictly less than newSessionStartedAt (the sentinel zero counts as
|
|
// "older"). LastSeen is recorded by the database at the moment the
|
|
// row is updated — never by the caller — so it always reflects the
|
|
// real write time even under lock contention.
|
|
// Returns true when the update happened, false when this stream lost
|
|
// the race against a newer session.
|
|
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
|
|
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
|
|
// resets SessionStartedAt to zero, but only when the stored
|
|
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
|
|
// recorded by the database. Returns true when the update happened,
|
|
// false when a newer session has taken over.
|
|
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
|
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
|
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
|
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
|
|
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
|
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
|
|
SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error
|
|
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
|
|
|
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
|
SaveRoute(ctx context.Context, route *route.Route) error
|
|
DeleteRoute(ctx context.Context, accountID, routeID string) error
|
|
|
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
|
SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error
|
|
DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error
|
|
|
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]netip.Addr, error)
|
|
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
|
|
|
|
GetInstallationID() string
|
|
SaveInstallationID(ctx context.Context, ID string) error
|
|
|
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
|
AcquireGlobalLock(ctx context.Context) func()
|
|
|
|
// Close should close the store persisting all unsaved data.
|
|
Close(ctx context.Context) error
|
|
// GetStoreEngine should return Engine of the current store implementation.
|
|
// This is also a method of metrics.DataSource interface.
|
|
GetStoreEngine() types.Engine
|
|
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
|
|
|
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
|
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
|
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
|
|
DeleteNetwork(ctx context.Context, accountID, networkID string) error
|
|
|
|
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
|
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
|
|
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
|
SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
|
|
DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
|
|
|
|
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
|
GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error)
|
|
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
|
|
GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error)
|
|
SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error
|
|
DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error
|
|
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
|
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
|
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
|
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
|
|
MarkAccountPrimary(ctx context.Context, accountID string) error
|
|
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
|
UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error
|
|
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
|
|
|
|
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
|
SetFieldEncrypt(enc *crypt.FieldEncrypt)
|
|
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
|
|
|
CreateZone(ctx context.Context, zone *zones.Zone) error
|
|
UpdateZone(ctx context.Context, zone *zones.Zone) error
|
|
DeleteZone(ctx context.Context, accountID, zoneID string) error
|
|
GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error)
|
|
GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error)
|
|
GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error)
|
|
|
|
CreateDNSRecord(ctx context.Context, record *records.Record) error
|
|
UpdateDNSRecord(ctx context.Context, record *records.Record) error
|
|
DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error
|
|
GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error)
|
|
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
|
|
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
|
|
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
|
|
CreatePeerJob(ctx context.Context, job *types.Job) error
|
|
CompletePeerJob(ctx context.Context, job *types.Job) error
|
|
GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error)
|
|
GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error)
|
|
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error
|
|
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
|
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
|
|
|
CreateService(ctx context.Context, service *rpservice.Service) error
|
|
UpdateService(ctx context.Context, service *rpservice.Service) error
|
|
DeleteService(ctx context.Context, accountID, serviceID string) error
|
|
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
|
GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
|
|
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
|
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
|
|
|
RenewEphemeralService(ctx context.Context, accountID, peerID, serviceID string) error
|
|
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
|
|
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
|
|
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
|
|
GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error)
|
|
GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error)
|
|
|
|
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
|
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
|
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
|
|
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
|
|
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
|
|
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
|
|
|
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
|
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
|
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
|
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
|
|
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
|
|
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
|
|
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
|
|
|
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
|
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
|
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
|
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
|
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
|
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
|
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
|
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
|
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
|
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
|
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
|
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
|
|
|
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
|
|
|
GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error)
|
|
}
|
|
|
|
// Names of the environment variables that carry the database DSN. Each
// engine has a current NB_-prefixed name and a legacy NETBIRD_-prefixed name.
const (
	postgresDsnEnv       = "NB_STORE_ENGINE_POSTGRES_DSN"
	postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
	mysqlDsnEnv          = "NB_STORE_ENGINE_MYSQL_DSN"
	mysqlDsnEnvLegacy    = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
|
|
|
|
// lookupDSNEnv returns the value of the environment variable nbKey when it is
// set, and otherwise the value of legacyKey.
//
// The boolean reports whether the variable that supplied the value was set.
// A variable that is set to the empty string still counts as set, so an empty
// nbKey takes precedence over a non-empty legacyKey.
func lookupDSNEnv(nbKey, legacyKey string) (string, bool) {
	if v, ok := os.LookupEnv(nbKey); ok {
		return v, true
	}
	return os.LookupEnv(legacyKey)
}
|
|
|
|
// supportedEngines lists the engine values that getStoreEngineFromEnv accepts
// from the environment.
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
|
|
|
|
func getStoreEngineFromEnv() types.Engine {
|
|
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
|
|
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
|
|
if !ok {
|
|
return ""
|
|
}
|
|
|
|
value := types.Engine(strings.ToLower(kind))
|
|
if slices.Contains(supportedEngines, value) {
|
|
return value
|
|
}
|
|
|
|
return types.SqliteStoreEngine
|
|
}
|
|
|
|
// getStoreEngine determines the store engine to use.
|
|
// If no engine is specified, it attempts to retrieve it from the environment.
|
|
// If still not specified, it defaults to using SQLite.
|
|
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
|
|
func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine {
|
|
if kind == "" {
|
|
kind = getStoreEngineFromEnv()
|
|
if kind == "" {
|
|
kind = types.SqliteStoreEngine
|
|
|
|
// Migrate if it is the first run with a JSON file existing and no SQLite file present
|
|
jsonStoreFile := filepath.Join(dataDir, storeFileName)
|
|
sqliteStoreFile := filepath.Join(dataDir, storeSqliteFileName)
|
|
|
|
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
|
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
|
|
|
// Attempt to migratePreAuto from JSON store to SQLite
|
|
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
|
log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err)
|
|
kind = types.FileStoreEngine
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return kind
|
|
}
|
|
|
|
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
|
|
func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
|
|
kind = getStoreEngine(ctx, dataDir, kind)
|
|
|
|
if err := checkFileStoreEngine(kind, dataDir); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch kind {
|
|
case types.SqliteStoreEngine:
|
|
log.WithContext(ctx).Info("using SQLite store engine")
|
|
return NewSqliteStore(ctx, dataDir, metrics, skipMigration)
|
|
case types.PostgresStoreEngine:
|
|
log.WithContext(ctx).Info("using Postgres store engine")
|
|
return newPostgresStore(ctx, metrics, skipMigration)
|
|
case types.MysqlStoreEngine:
|
|
log.WithContext(ctx).Info("using MySQL store engine")
|
|
return newMysqlStore(ctx, metrics, skipMigration)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported kind of store: %s", kind)
|
|
}
|
|
}
|
|
|
|
func checkFileStoreEngine(kind types.Engine, dataDir string) error {
|
|
if kind == types.FileStoreEngine {
|
|
storeFile := filepath.Join(dataDir, storeFileName)
|
|
if util.FileExists(storeFile) {
|
|
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
|
|
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// migratePreAuto migrates the SQLite database to the latest schema
|
|
func migratePreAuto(ctx context.Context, db *gorm.DB) error {
|
|
migrations := getMigrationsPreAuto(ctx)
|
|
|
|
for _, m := range migrations {
|
|
if err := m(db); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
|
return []migrationFunc{
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateFieldFromGobToJSON[route.Route, []string](ctx, db, "peer_groups")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "location_connection_ip", "")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db)
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNewField[resourceTypes.NetworkResource](ctx, db, "enabled", true)
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true)
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.DropIndex[networkTypes.Network](ctx, db, "idx_networks_id")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.DropIndex[resourceTypes.NetworkResource](ctx, db, "idx_network_resources_id")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNewField[types.User](ctx, db, "name", "")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateNewField[types.User](ctx, db, "email", "")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id")
|
|
},
|
|
}
|
|
}
|
|
|
|
// migratePostAuto migrates the SQLite database to the latest schema
|
|
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
|
migrations := getMigrationsPostAuto(ctx)
|
|
|
|
for _, m := range migrations {
|
|
if err := m(db); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
|
return []migrationFunc{
|
|
func(db *gorm.DB) error {
|
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(accountID, id, value string) any {
|
|
return &types.GroupPeer{
|
|
AccountID: accountID,
|
|
GroupID: id,
|
|
PeerID: value,
|
|
}
|
|
})
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.DropIndex[nbpeer.Peer](ctx, db, "idx_peers_key")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
|
},
|
|
func(db *gorm.DB) error {
|
|
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
|
// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database
|
|
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
|
|
kind := getStoreEngineFromEnv()
|
|
if kind == "" {
|
|
kind = types.SqliteStoreEngine
|
|
}
|
|
|
|
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
|
if runtime.GOOS == "windows" {
|
|
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
|
storeStr = storeSqliteFileName
|
|
}
|
|
|
|
file := filepath.Join(dataDir, storeStr)
|
|
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if filename != "" {
|
|
err = LoadSQL(db, filename)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to load SQL file: %v", err)
|
|
}
|
|
}
|
|
|
|
store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil, false)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
|
|
}
|
|
|
|
err = addAllGroupToAccount(ctx, store)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to add all group to account: %v", err)
|
|
}
|
|
|
|
var sqlStore Store
|
|
var cleanup func()
|
|
|
|
maxRetries := 2
|
|
for i := 0; i < maxRetries; i++ {
|
|
sqlStore, cleanup, err = getSqlStoreEngine(ctx, store, kind)
|
|
if err == nil {
|
|
return sqlStore, cleanup, nil
|
|
}
|
|
if i < maxRetries-1 {
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
}
|
|
return nil, nil, fmt.Errorf("failed to create test store after %d attempts: %v", maxRetries, err)
|
|
}
|
|
|
|
func addAllGroupToAccount(ctx context.Context, store Store) error {
|
|
allAccounts := store.GetAllAccounts(ctx)
|
|
for _, account := range allAccounts {
|
|
shouldSave := false
|
|
|
|
_, err := account.GetGroupAll()
|
|
if err != nil {
|
|
if err := account.AddAllGroup(false); err != nil {
|
|
return err
|
|
}
|
|
shouldSave = true
|
|
}
|
|
|
|
if shouldSave {
|
|
err = store.SaveAccount(ctx, account)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) {
|
|
var cleanup func()
|
|
var err error
|
|
switch kind {
|
|
case types.PostgresStoreEngine:
|
|
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
|
|
case types.MysqlStoreEngine:
|
|
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
|
|
default:
|
|
cleanup = func() {
|
|
// sqlite doesn't need to be cleaned up
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, cleanup, fmt.Errorf("failed to create test store: %v", err)
|
|
}
|
|
|
|
closeConnection := func() {
|
|
cleanup()
|
|
store.Close(ctx)
|
|
if store.pool != nil {
|
|
store.pool.Close()
|
|
}
|
|
}
|
|
|
|
return store, closeConnection, nil
|
|
}
|
|
|
|
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
|
|
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
|
|
if !ok || dsn == "" {
|
|
var err error
|
|
_, dsn, err = testutil.CreatePostgresTestContainer()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
if dsn == "" {
|
|
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
|
}
|
|
|
|
db, err := openDBWithRetry(dsn, kind, 5)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
|
|
}
|
|
|
|
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
|
|
|
sqlDB, _ := db.DB()
|
|
if sqlDB != nil {
|
|
sqlDB.Close()
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return store, cleanup, nil
|
|
}
|
|
|
|
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
|
|
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
|
|
if !ok || dsn == "" {
|
|
var err error
|
|
_, dsn, err = testutil.CreateMysqlTestContainer()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
if dsn == "" {
|
|
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
|
}
|
|
|
|
db, err := openDBWithRetry(dsn, kind, 5)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
|
|
}
|
|
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err)
|
|
}
|
|
sqlDB.SetMaxOpenConns(1)
|
|
sqlDB.SetMaxIdleConns(1)
|
|
|
|
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
|
|
|
sqlDB.Close()
|
|
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return store, cleanup, nil
|
|
}
|
|
|
|
func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) {
|
|
var db *gorm.DB
|
|
var err error
|
|
|
|
for i := range maxRetries {
|
|
switch engine {
|
|
case types.PostgresStoreEngine:
|
|
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
case types.MysqlStoreEngine:
|
|
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
|
}
|
|
|
|
if err == nil {
|
|
return db, nil
|
|
}
|
|
|
|
if i < maxRetries-1 {
|
|
waitTime := time.Duration(100*(i+1)) * time.Millisecond
|
|
time.Sleep(waitTime)
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
|
|
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
|
|
|
|
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
|
|
return "", nil, fmt.Errorf("failed to create database: %v", err)
|
|
}
|
|
|
|
originalDSN := dsn
|
|
|
|
cleanup := func() {
|
|
var dropDB *gorm.DB
|
|
var err error
|
|
|
|
switch engine {
|
|
case types.PostgresStoreEngine:
|
|
dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{
|
|
SkipDefaultTransaction: true,
|
|
PrepareStmt: false,
|
|
})
|
|
if err != nil {
|
|
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
|
sqlDB.Close()
|
|
}
|
|
}()
|
|
|
|
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
|
sqlDB.SetMaxOpenConns(1)
|
|
sqlDB.SetMaxIdleConns(0)
|
|
sqlDB.SetConnMaxLifetime(time.Second)
|
|
}
|
|
|
|
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error
|
|
|
|
case types.MysqlStoreEngine:
|
|
dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{
|
|
SkipDefaultTransaction: true,
|
|
PrepareStmt: false,
|
|
})
|
|
if err != nil {
|
|
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
|
sqlDB.Close()
|
|
}
|
|
}()
|
|
|
|
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
|
sqlDB.SetMaxOpenConns(1)
|
|
sqlDB.SetMaxIdleConns(0)
|
|
sqlDB.SetConnMaxLifetime(time.Second)
|
|
}
|
|
|
|
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error
|
|
}
|
|
|
|
if err != nil {
|
|
log.Errorf("failed to drop database %s: %v", dbName, err)
|
|
}
|
|
}
|
|
|
|
return replaceDBName(dsn, dbName), cleanup, nil
|
|
}
|
|
|
|
// dsnDBNamePattern matches a run of characters free of '/' and '?' that is
// preceded by ':', '/' or '@' and followed by '?' or the end of the string.
// Compiled once at package scope instead of on every replaceDBName call.
var dsnDBNamePattern = regexp.MustCompile(`(?P<pre>[:/@])(?P<dbname>[^/?]+)(?P<post>\?|$)`)

// replaceDBName returns dsn with its database-name segment replaced by
// newDBName. The segment is whatever dsnDBNamePattern matches, e.g. the part
// after the last '/' and before an optional '?' in URL-style and MySQL-style
// DSNs. A dsn without such a segment is returned unchanged.
//
// newDBName is used as a regexp replacement template, so a '$' in it would be
// interpreted as a group reference.
func replaceDBName(dsn, newDBName string) string {
	return dsnDBNamePattern.ReplaceAllString(dsn, `${pre}`+newDBName+`${post}`)
}
|
|
|
|
func LoadSQL(db *gorm.DB, filepath string) error {
|
|
sqlContent, err := os.ReadFile(filepath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
queries := strings.Split(string(sqlContent), ";")
|
|
|
|
for _, query := range queries {
|
|
query = strings.TrimSpace(query)
|
|
if query != "" {
|
|
err := db.Exec(query).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MigrateFileStoreToSqlite migrates the file store to the SQLite store.
|
|
func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
|
|
fileStorePath := path.Join(dataDir, storeFileName)
|
|
if _, err := os.Stat(fileStorePath); errors.Is(err, os.ErrNotExist) {
|
|
return fmt.Errorf("%s doesn't exist, couldn't continue the operation", fileStorePath)
|
|
}
|
|
|
|
sqlStorePath := path.Join(dataDir, storeSqliteFileName)
|
|
if _, err := os.Stat(sqlStorePath); err == nil {
|
|
return fmt.Errorf("%s already exists, couldn't continue the operation", sqlStorePath)
|
|
}
|
|
|
|
fstore, err := NewFileStore(ctx, dataDir, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed creating file store: %s: %v", dataDir, err)
|
|
}
|
|
|
|
fsStoreAccounts := len(fstore.GetAllAccounts(ctx))
|
|
log.WithContext(ctx).Infof("%d account will be migrated from file store %s to sqlite store %s",
|
|
fsStoreAccounts, fileStorePath, sqlStorePath)
|
|
|
|
store, err := NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil, true)
|
|
if err != nil {
|
|
return fmt.Errorf("failed creating file store: %s: %v", dataDir, err)
|
|
}
|
|
|
|
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
|
if fsStoreAccounts != sqliteStoreAccounts {
|
|
return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d",
|
|
fsStoreAccounts, sqliteStoreAccounts)
|
|
}
|
|
|
|
return nil
|
|
}
|