add custom store cache

This commit is contained in:
Pascal Fischer
2025-10-17 15:57:40 +02:00
parent 8393bf1b17
commit df101bf071
4 changed files with 720 additions and 39 deletions

View File

@@ -0,0 +1,129 @@
package cache
import (
"context"
"sync"
)
// DualKeyCache provides a caching mechanism where each entry has two keys:
// - Primary key (e.g., objectID): used for accessing and invalidating specific entries
// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
reverseLookup map[K1]K2 // Primary key -> Secondary key
}
// NewDualKeyCache creates a new dual-key cache
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
return &DualKeyCache[K1, K2, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
reverseLookup: make(map[K1]K2),
}
}
// Get retrieves a value from the cache using the primary key
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with both primary and secondary keys
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldSecondaryKey)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = secondaryKey
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if secondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, secondaryKey)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.reverseLookup = make(map[K1]K2)
}
// Size returns the number of entries in the cache
func (c *DualKeyCache[K1, K2, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return both the value and the secondary key (extracted from the value)
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, value)
return value, nil
}

View File

@@ -0,0 +1,77 @@
package cache
import (
"context"
"sync"
)
// SingleKeyCache provides a simple caching mechanism with a single key
type SingleKeyCache[K comparable, V any] struct {
mu sync.RWMutex
cache map[K]V // Key -> Value
}
// NewSingleKeyCache creates a new single-key cache
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
return &SingleKeyCache[K, V]{
cache: make(map[K]V),
}
}
// Get retrieves a value from the cache using the key
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.cache[key]
return value, ok
}
// Set stores a value in the cache with the given key
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = value
}
// Invalidate removes an entry using the key
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
// InvalidateAll removes all entries from the cache
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[K]V)
}
// Size returns the number of entries in the cache
func (c *SingleKeyCache[K, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.cache)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
if value, ok := c.Get(ctx, key); ok {
return value, nil
}
value, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, key, value)
return value, nil
}

View File

@@ -0,0 +1,242 @@
package cache
import (
"context"
"sync"
)
// TripleKeyCache provides a caching mechanism where each entry has three keys:
// - Primary key (K1): used for accessing and invalidating specific entries
// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key
// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys
reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys
}
type keyPair[K2 comparable, K3 comparable] struct {
secondary K2
tertiary K3
}
// NewTripleKeyCache creates a new triple-key cache
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
return &TripleKeyCache[K1, K2, K3, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
tertiaryIndex: make(map[K3]map[K1]struct{}),
reverseLookup: make(map[K1]keyPair[K2, K3]),
}
}
// Get retrieves a value from the cache using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with primary, secondary, and tertiary keys
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldKeys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldKeys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, oldKeys.tertiary)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = keyPair[K2, K3]{
secondary: secondaryKey,
tertiary: tertiaryKey,
}
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
if _, exists := c.tertiaryIndex[tertiaryKey]; !exists {
c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{})
}
c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if keys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists {
delete(tertiaryPrimaryKeys, primaryKey)
if len(tertiaryPrimaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateByTertiaryKey removes all entries with the given tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.tertiaryIndex[tertiaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists {
delete(secondaryPrimaryKeys, primaryKey)
if len(secondaryPrimaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.tertiaryIndex, tertiaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.tertiaryIndex = make(map[K3]map[K1]struct{})
c.reverseLookup = make(map[K1]keyPair[K2, K3])
}
// Size returns the number of entries in the cache
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value)
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}

View File

@@ -30,6 +30,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store/cache"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
@@ -55,6 +56,17 @@ type SqlStore struct {
metrics telemetry.AppMetrics
installationPK int
storeEngine types.Engine
// Cache for user data: primary key = userID, secondary key = PatID, tertiary key = accountID
userCache *cache.TripleKeyCache[string, string, string, *types.User]
// Cache for account settings: primary key = accountID, secondary key = accountID
settingsCache *cache.SingleKeyCache[string, *types.Settings]
// Cache for peer: primary key = peerKey, secondary key = peerID, tertiary key = accountID
peerCache *cache.TripleKeyCache[string, string, string, *nbpeer.Peer]
// Cache for accountID: primary key = peerKey, secondary key = accountID
accountIDCache *cache.DualKeyCache[string, string, string]
// Cache for domain and category: primary key = accountID
domainDataCache *cache.SingleKeyCache[string, *DomainData]
}
type installation struct {
@@ -62,6 +74,11 @@ type installation struct {
InstallationIDValue string
}
type DomainData struct {
Domain string
Category string
}
type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance.
@@ -94,7 +111,17 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
if skipMigration {
log.WithContext(ctx).Infof("skipping migration")
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
return &SqlStore{
db: db,
storeEngine: storeEngine,
metrics: metrics,
installationPK: 1,
userCache: cache.NewTripleKeyCache[string, string, string, *types.User](),
settingsCache: cache.NewSingleKeyCache[string, *types.Settings](),
peerCache: cache.NewTripleKeyCache[string, string, string, *nbpeer.Peer](),
accountIDCache: cache.NewDualKeyCache[string, string, string](),
domainDataCache: cache.NewSingleKeyCache[string, *DomainData](),
}, nil
}
if err := migratePreAuto(ctx, db); err != nil {
@@ -113,7 +140,17 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePostAuto: %w", err)
}
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
return &SqlStore{
db: db,
storeEngine: storeEngine,
metrics: metrics,
installationPK: 1,
userCache: cache.NewTripleKeyCache[string, string, string, *types.User](),
settingsCache: cache.NewSingleKeyCache[string, *types.Settings](),
peerCache: cache.NewTripleKeyCache[string, string, string, *nbpeer.Peer](),
accountIDCache: cache.NewDualKeyCache[string, string, string](),
domainDataCache: cache.NewSingleKeyCache[string, *DomainData](),
}, nil
}
func GetKeyQueryCondition(s *SqlStore) string {
@@ -191,6 +228,12 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
return nil
})
s.userCache.InvalidateBySecondaryKey(ctx, account.Id)
s.settingsCache.Invalidate(ctx, account.Id)
s.accountIDCache.InvalidateByPrimaryKey(ctx, account.Id)
s.peerCache.InvalidateByTertiaryKey(ctx, account.Id)
s.domainDataCache.Invalidate(ctx, account.Id)
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
@@ -282,6 +325,12 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
return nil
})
s.userCache.InvalidateBySecondaryKey(ctx, account.Id)
s.settingsCache.Invalidate(ctx, account.Id)
s.accountIDCache.InvalidateByPrimaryKey(ctx, account.Id)
s.peerCache.InvalidateByTertiaryKey(ctx, account.Id)
s.domainDataCache.Invalidate(ctx, account.Id)
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
@@ -347,6 +396,8 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
return nil
})
s.peerCache.InvalidateByPrimaryKey(ctx, peer.ID)
if err != nil {
return err
}
@@ -374,6 +425,8 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
return err
}
s.domainDataCache.Invalidate(ctx, accountID)
if result.RowsAffected == 0 {
err = status.Errorf(status.NotFound, "account %s", accountID)
return err
@@ -402,6 +455,8 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
return err
}
s.peerCache.InvalidateByPrimaryKey(ctx, peerID)
if result.RowsAffected == 0 {
err = status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
return err
@@ -429,6 +484,8 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return err
}
s.peerCache.InvalidateByPrimaryKey(ctx, peerWithLocation.ID)
if result.RowsAffected == 0 {
err = status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
return err
@@ -452,6 +509,11 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
err = status.Errorf(status.Internal, "failed to save users to store")
return err
}
for _, user := range users {
s.userCache.InvalidateByPrimaryKey(ctx, user.Id)
}
return nil
}
@@ -466,6 +528,9 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
err = status.Errorf(status.Internal, "failed to save user to store")
return err
}
s.userCache.InvalidateByPrimaryKey(ctx, user.Id)
return nil
}
@@ -629,6 +694,31 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
var err error
defer s.trackStoreOperation(time.Now(), "GetUserByPATID", &err)()
// If no locking is required, try to get from cache first
if lockStrength == LockingStrengthNone && s.userCache != nil {
user, cacheErr := s.userCache.GetOrSetBySecondaryKey(ctx, patID, func() (*types.User, string, string, error) {
return s.getUserByPATIDFromDB(ctx, lockStrength, patID)
})
if cacheErr != nil {
err = cacheErr
return nil, err
}
return user, nil
}
// If locking is required, bypass cache and fetch directly from DB
user, _, _, dbErr := s.getUserByPATIDFromDB(ctx, lockStrength, patID)
if dbErr != nil {
err = dbErr
return nil, err
}
return user, nil
}
func (s *SqlStore) getUserByPATIDFromDB(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, string, string, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
@@ -640,21 +730,43 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
Where("personal_access_tokens.id = ?", patID).Take(&user)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = status.NewPATNotFoundError(patID)
return nil, err
err := status.NewPATNotFoundError(patID)
return nil, "", "", err
}
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
err = status.NewGetUserFromStoreError()
return nil, err
err := status.NewGetUserFromStoreError()
return nil, "", "", err
}
return &user, nil
return &user, patID, user.AccountID, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
var err error
defer s.trackStoreOperation(time.Now(), "GetUserByUserID", &err)()
// If no locking is required, try to get from cache first
if lockStrength == LockingStrengthNone && s.userCache != nil {
user, cacheErr := s.userCache.GetOrSet(ctx, userID, func() (*types.User, string, string, error) {
return s.getUserByUserIDFromDB(ctx, lockStrength, userID)
})
if cacheErr != nil {
err = cacheErr
return nil, err
}
return user, nil
}
// If locking is required, bypass cache and fetch directly from DB
user, _, _, dbErr := s.getUserByUserIDFromDB(ctx, lockStrength, userID)
if dbErr != nil {
err = dbErr
return nil, err
}
return user, nil
}
func (s *SqlStore) getUserByUserIDFromDB(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, string, string, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
@@ -667,14 +779,13 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = status.NewUserNotFoundError(userID)
return nil, err
return nil, "", "", status.NewUserNotFoundError(userID)
}
err = status.NewGetUserFromStoreError()
return nil, err
return nil, "", "", status.NewGetUserFromStoreError()
}
return &user, nil
// Return user, accountID (secondary key), and no error
return &user, "", user.AccountID, nil
}
func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error {
@@ -695,6 +806,8 @@ func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) err
return err
}
s.userCache.InvalidateByPrimaryKey(ctx, userID)
return nil
}
@@ -1085,16 +1198,40 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
var err error
defer s.trackStoreOperation(time.Now(), "GetAccountIDByPeerPubKey", &err)()
// Try to get from cache first
if s.accountIDCache != nil {
accountID, cacheErr := s.accountIDCache.GetOrSet(ctx, peerKey, func() (string, string, error) {
accountId, err := s.getAccountIDByPeerPubKeyFromDB(ctx, peerKey)
if err != nil {
return "", "", err
}
return accountId, accountId, nil
})
if cacheErr != nil {
err = cacheErr
return "", err
}
return accountID, nil
}
// Fallback to direct DB query if cache is not available
accountID, dbErr := s.getAccountIDByPeerPubKeyFromDB(ctx, peerKey)
if dbErr != nil {
err = dbErr
return "", err
}
return accountID, nil
}
func (s *SqlStore) getAccountIDByPeerPubKeyFromDB(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).Take(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = status.Errorf(status.NotFound, "account not found: index lookup failed")
return "", err
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
err = status.NewGetAccountFromStoreError(result.Error)
return "", err
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -1265,6 +1402,28 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
var err error
defer s.trackStoreOperation(time.Now(), "GetPeerByPeerPubKey", &err)()
// If no locking is required, try to get from cache first
if lockStrength == LockingStrengthNone && s.peerCache != nil {
peer, cacheErr := s.peerCache.GetOrSet(ctx, peerKey, func() (*nbpeer.Peer, string, string, error) {
return s.getPeerByPeerPubKeyFromDB(ctx, lockStrength, peerKey)
})
if cacheErr != nil {
err = cacheErr
return nil, err
}
return peer, nil
}
// If locking is required, bypass cache and fetch directly from DB
peer, _, _, dbErr := s.getPeerByPeerPubKeyFromDB(ctx, lockStrength, peerKey)
if dbErr != nil {
err = dbErr
return nil, err
}
return peer, nil
}
func (s *SqlStore) getPeerByPeerPubKeyFromDB(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, string, string, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
@@ -1278,33 +1437,51 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = status.NewPeerNotFoundError(peerKey)
return nil, err
return nil, "", "", status.NewPeerNotFoundError(peerKey)
}
err = status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
return nil, err
return nil, "", "", status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
return &peer, nil
return &peer, peer.ID, peer.AccountID, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
var err error
defer s.trackStoreOperation(time.Now(), "GetAccountSettings", &err)()
// If no locking is required, try to get from cache first
if lockStrength == LockingStrengthNone && s.settingsCache != nil {
settings, cacheErr := s.settingsCache.GetOrSet(ctx, accountID, func() (*types.Settings, error) {
return s.getAccountSettingsFromDB(ctx, lockStrength, accountID)
})
if cacheErr != nil {
err = cacheErr
return nil, err
}
return settings, nil
}
// If locking is required, bypass cache and fetch directly from DB
settings, dbErr := s.getAccountSettingsFromDB(ctx, lockStrength, accountID)
if dbErr != nil {
err = dbErr
return nil, err
}
return settings, nil
}
func (s *SqlStore) getAccountSettingsFromDB(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountSettings types.AccountSettings
if err = tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil {
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
err = status.Errorf(status.NotFound, "settings not found")
return nil, err
return nil, status.Errorf(status.NotFound, "settings not found")
}
err = status.Errorf(status.Internal, "issue getting settings from store: %s", err)
return nil, err
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
}
return accountSettings.Settings, nil
}
@@ -1358,6 +1535,8 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
return err
}
s.userCache.InvalidateByPrimaryKey(ctx, userID)
return nil
}
@@ -1881,6 +2060,28 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
var err error
defer s.trackStoreOperation(time.Now(), "GetPeerByID", &err)()
// If no locking is required, try to get from cache first
if lockStrength == LockingStrengthNone && s.peerCache != nil {
peer, cacheErr := s.peerCache.GetOrSetBySecondaryKey(ctx, peerID, func() (*nbpeer.Peer, string, string, error) {
return s.getPeerByIDFromDB(ctx, lockStrength, accountID, peerID)
})
if cacheErr != nil {
err = cacheErr
return nil, err
}
return peer, nil
}
// If locking is required, bypass cache and fetch directly from DB
peer, _, _, dbErr := s.getPeerByIDFromDB(ctx, lockStrength, accountID, peerID)
if dbErr != nil {
err = dbErr
return nil, err
}
return peer, nil
}
func (s *SqlStore) getPeerByIDFromDB(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, string, string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
@@ -1891,14 +2092,12 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
Take(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = status.NewPeerNotFoundError(peerID)
return nil, err
return nil, "", "", status.NewPeerNotFoundError(peerID)
}
err = status.Errorf(status.Internal, "failed to get peer from store")
return nil, err
return nil, "", "", status.Errorf(status.Internal, "failed to get peer from store")
}
return peer, nil
return peer, peer.Key, peer.AccountID, nil
}
// GetPeersByIDs retrieves peers by their IDs and account ID.
@@ -2017,6 +2216,9 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
return err
}
s.peerCache.InvalidateByPrimaryKey(ctx, peerID)
s.accountIDCache.InvalidateBySecondaryKey(ctx, accountID)
return nil
}
@@ -2061,9 +2263,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
metrics: s.metrics,
storeEngine: s.storeEngine,
db: tx,
metrics: s.metrics,
storeEngine: s.storeEngine,
peerCache: s.peerCache,
userCache: s.userCache,
accountIDCache: s.accountIDCache,
domainDataCache: s.domainDataCache,
settingsCache: s.settingsCache,
}
}
@@ -2124,6 +2331,28 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
var err error
defer s.trackStoreOperation(time.Now(), "GetAccountDomainAndCategory", &err)()
// If no locking is required, try to get from cache first
if lockStrength == LockingStrengthNone && s.domainDataCache != nil {
domainData, cacheErr := s.domainDataCache.GetOrSet(ctx, accountID, func() (*DomainData, error) {
return s.getAccountDomainAndCategoryFromDB(ctx, lockStrength, accountID)
})
if cacheErr != nil {
err = cacheErr
return "", "", err
}
return domainData.Domain, domainData.Category, nil
}
// If locking is required, bypass cache and fetch directly from DB
domainData, dbErr := s.getAccountDomainAndCategoryFromDB(ctx, lockStrength, accountID)
if dbErr != nil {
err = dbErr
return "", "", err
}
return domainData.Domain, domainData.Category, nil
}
func (s *SqlStore) getAccountDomainAndCategoryFromDB(ctx context.Context, lockStrength LockingStrength, accountID string) (*DomainData, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
@@ -2134,14 +2363,15 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
Where(idQueryCondition, accountID).Take(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
err = status.Errorf(status.NotFound, "account not found")
return "", "", err
return nil, status.Errorf(status.NotFound, "account not found")
}
err = status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
return "", "", err
return nil, status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}
return account.Domain, account.DomainCategory, nil
return &DomainData{
Domain: account.Domain,
Category: account.DomainCategory,
}, nil
}
// GetGroupByID retrieves a group by ID and account ID.
@@ -2822,6 +3052,8 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, se
return err
}
s.settingsCache.Invalidate(ctx, accountID)
if result.RowsAffected == 0 {
err = status.NewAccountNotFoundError(accountID)
return err
@@ -3421,6 +3653,7 @@ func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) err
err = status.Errorf(status.Internal, "failed to mark account as primary")
return err
}
s.domainDataCache.Invalidate(ctx, accountID)
if result.RowsAffected == 0 {
err = status.NewAccountNotFoundError(accountID)