diff --git a/management/server/store/cache/dual_key_cache.go b/management/server/store/cache/dual_key_cache.go new file mode 100644 index 000000000..1419064f9 --- /dev/null +++ b/management/server/store/cache/dual_key_cache.go @@ -0,0 +1,129 @@ +package cache + +import ( + "context" + "sync" +) + +// DualKeyCache provides a caching mechanism where each entry has two keys: +// - Primary key (e.g., objectID): used for accessing and invalidating specific entries +// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key +type DualKeyCache[K1 comparable, K2 comparable, V any] struct { + mu sync.RWMutex + primaryIndex map[K1]V // Primary key -> Value + secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys + reverseLookup map[K1]K2 // Primary key -> Secondary key +} + +// NewDualKeyCache creates a new dual-key cache +func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] { + return &DualKeyCache[K1, K2, V]{ + primaryIndex: make(map[K1]V), + secondaryIndex: make(map[K2]map[K1]struct{}), + reverseLookup: make(map[K1]K2), + } +} + +// Get retrieves a value from the cache using the primary key +func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + value, ok := c.primaryIndex[primaryKey] + return value, ok +} + +// Set stores a value in the cache with both primary and secondary keys +func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) { + c.mu.Lock() + defer c.mu.Unlock() + + if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists { + if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok { + delete(primaryKeys, primaryKey) + if len(primaryKeys) == 0 { + delete(c.secondaryIndex, oldSecondaryKey) + } + } + } + + c.primaryIndex[primaryKey] = value + c.reverseLookup[primaryKey] = secondaryKey + + if _, exists := c.secondaryIndex[secondaryKey]; !exists { + c.secondaryIndex[secondaryKey] = make(map[K1]struct{}) + } + c.secondaryIndex[secondaryKey][primaryKey] = struct{}{} +} + +// InvalidateByPrimaryKey removes an entry using the primary key +func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) { + c.mu.Lock() + defer c.mu.Unlock() + + if secondaryKey, exists := c.reverseLookup[primaryKey]; exists { + if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok { + delete(primaryKeys, primaryKey) + if len(primaryKeys) == 0 { + delete(c.secondaryIndex, secondaryKey) + } + } + delete(c.reverseLookup, primaryKey) + } + + delete(c.primaryIndex, primaryKey) +} + +// InvalidateBySecondaryKey removes all entries with the given secondary key +func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) { + c.mu.Lock() + defer c.mu.Unlock() + + primaryKeys, exists := c.secondaryIndex[secondaryKey] + if !exists { + return + } + + for primaryKey := range primaryKeys { + delete(c.primaryIndex, primaryKey) + delete(c.reverseLookup, primaryKey) + } + + delete(c.secondaryIndex, secondaryKey) +} + +// InvalidateAll removes all entries from the cache +func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) { + c.mu.Lock() + defer c.mu.Unlock() + + c.primaryIndex = make(map[K1]V) + c.secondaryIndex = make(map[K2]map[K1]struct{}) + c.reverseLookup = make(map[K1]K2) +} + +// Size returns the number of entries in the cache +func (c *DualKeyCache[K1, K2, V]) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.primaryIndex) +} + +// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found +// The loadFunc should return both the value and the secondary key (extracted from the value) +func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) { + if value, ok := c.Get(ctx, primaryKey); ok { + return value, nil + } + + value, secondaryKey, err := loadFunc() + if err != nil { + var zero V + return zero, err + } + + c.Set(ctx, primaryKey, secondaryKey, value) + + return value, nil +} diff --git a/management/server/store/cache/single_key_cache.go b/management/server/store/cache/single_key_cache.go new file mode 100644 index 000000000..218cfe7be --- /dev/null +++ b/management/server/store/cache/single_key_cache.go @@ -0,0 +1,77 @@ +package cache + +import ( + "context" + "sync" +) + +// SingleKeyCache provides a simple caching mechanism with a single key +type SingleKeyCache[K comparable, V any] struct { + mu sync.RWMutex + cache map[K]V // Key -> Value +} + +// NewSingleKeyCache creates a new single-key cache +func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] { + return &SingleKeyCache[K, V]{ + cache: make(map[K]V), + } +} + +// Get retrieves a value from the cache using the key +func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + value, ok := c.cache[key] + return value, ok +} + +// Set stores a value in the cache with the given key +func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cache[key] = value +} + +// Invalidate removes an entry using the key +func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.cache, key) +} + +// InvalidateAll removes all entries from the cache +func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cache = make(map[K]V) +} + +// Size returns the number of entries in the cache +func (c *SingleKeyCache[K, V]) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.cache) +} + +// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found +func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) { + if value, ok := c.Get(ctx, key); ok { + return value, nil + } + + value, err := loadFunc() + if err != nil { + var zero V + return zero, err + } + + c.Set(ctx, key, value) + + return value, nil +} diff --git a/management/server/store/cache/triple_key_cache.go b/management/server/store/cache/triple_key_cache.go new file mode 100644 index 000000000..88990694d --- /dev/null +++ b/management/server/store/cache/triple_key_cache.go @@ -0,0 +1,242 @@ +package cache + +import ( + "context" + "sync" +) + +// TripleKeyCache provides a caching mechanism where each entry has three keys: +// - Primary key (K1): used for accessing and invalidating specific entries +// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key +// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key +type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct { + mu sync.RWMutex + primaryIndex map[K1]V // Primary key -> Value + secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys + tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys + reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys +} + +type keyPair[K2 comparable, K3 comparable] struct { + secondary K2 + tertiary K3 +} + +// NewTripleKeyCache creates a new triple-key cache +func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] { + return &TripleKeyCache[K1, K2, K3, V]{ + primaryIndex: make(map[K1]V), + secondaryIndex: make(map[K2]map[K1]struct{}), + tertiaryIndex: make(map[K3]map[K1]struct{}), + reverseLookup: make(map[K1]keyPair[K2, K3]), + } +} + +// Get retrieves a value from the cache using the primary key +func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + value, ok := c.primaryIndex[primaryKey] + return value, ok +} + +// Set stores a value in the cache with primary, secondary, and tertiary keys +func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) { + c.mu.Lock() + defer c.mu.Unlock() + + if oldKeys, exists := c.reverseLookup[primaryKey]; exists { + if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok { + delete(primaryKeys, primaryKey) + if len(primaryKeys) == 0 { + delete(c.secondaryIndex, oldKeys.secondary) + } + } + if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok { + delete(primaryKeys, primaryKey) + if len(primaryKeys) == 0 { + delete(c.tertiaryIndex, oldKeys.tertiary) + } + } + } + + c.primaryIndex[primaryKey] = value + c.reverseLookup[primaryKey] = keyPair[K2, K3]{ + secondary: secondaryKey, + tertiary: tertiaryKey, + } + + if _, exists := c.secondaryIndex[secondaryKey]; !exists { + c.secondaryIndex[secondaryKey] = make(map[K1]struct{}) + } + c.secondaryIndex[secondaryKey][primaryKey] = struct{}{} + + if _, exists := c.tertiaryIndex[tertiaryKey]; !exists { + c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{}) + } + c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{} +} + +// InvalidateByPrimaryKey removes an entry using the primary key +func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) { + c.mu.Lock() + defer c.mu.Unlock() + + if keys, exists := c.reverseLookup[primaryKey]; exists { + if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok { + delete(primaryKeys, primaryKey) + if len(primaryKeys) == 0 { + delete(c.secondaryIndex, keys.secondary) + } + } + if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok { + delete(primaryKeys, primaryKey) + if len(primaryKeys) == 0 { + delete(c.tertiaryIndex, keys.tertiary) + } + } + delete(c.reverseLookup, primaryKey) + } + + delete(c.primaryIndex, primaryKey) +} + +// InvalidateBySecondaryKey removes all entries with the given secondary key +func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) { + c.mu.Lock() + defer c.mu.Unlock() + + primaryKeys, exists := c.secondaryIndex[secondaryKey] + if !exists { + return + } + + for primaryKey := range primaryKeys { + if keys, ok := c.reverseLookup[primaryKey]; ok { + if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists { + delete(tertiaryPrimaryKeys, primaryKey) + if len(tertiaryPrimaryKeys) == 0 { + delete(c.tertiaryIndex, keys.tertiary) + } + } + } + delete(c.primaryIndex, primaryKey) + delete(c.reverseLookup, primaryKey) + } + + delete(c.secondaryIndex, secondaryKey) +} + +// InvalidateByTertiaryKey removes all entries with the given tertiary key +func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) { + c.mu.Lock() + defer c.mu.Unlock() + + primaryKeys, exists := c.tertiaryIndex[tertiaryKey] + if !exists { + return + } + + for primaryKey := range primaryKeys { + if keys, ok := c.reverseLookup[primaryKey]; ok { + if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists { + delete(secondaryPrimaryKeys, primaryKey) + if len(secondaryPrimaryKeys) == 0 { + delete(c.secondaryIndex, keys.secondary) + } + } + } + delete(c.primaryIndex, primaryKey) + delete(c.reverseLookup, primaryKey) + } + + delete(c.tertiaryIndex, tertiaryKey) +} + +// InvalidateAll removes all entries from the cache +func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) { + c.mu.Lock() + defer c.mu.Unlock() + + c.primaryIndex = make(map[K1]V) + c.secondaryIndex = make(map[K2]map[K1]struct{}) + c.tertiaryIndex = make(map[K3]map[K1]struct{}) + c.reverseLookup = make(map[K1]keyPair[K2, K3]) +} + +// Size returns the number of entries in the cache +func (c *TripleKeyCache[K1, K2, K3, V]) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.primaryIndex) +} + +// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found +// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value) +func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) { + if value, ok := c.Get(ctx, primaryKey); ok { + return value, nil + } + + value, secondaryKey, tertiaryKey, err := loadFunc() + if err != nil { + var zero V + return zero, err + } + + c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value) + + return value, nil +} + +// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found +// The loadFunc should return the value, primary key, secondary key, and tertiary key +func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) { + c.mu.RLock() + if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 { + for primaryKey := range primaryKeys { + if value, ok := c.primaryIndex[primaryKey]; ok { + c.mu.RUnlock() + return value, nil + } + } + } + c.mu.RUnlock() + + value, primaryKey, tertiaryKey, err := loadFunc() + if err != nil { + var zero V + return zero, err + } + + c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value) + + return value, nil +} + +// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found +// The loadFunc should return the value, primary key, secondary key, and tertiary key +func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) { + c.mu.RLock() + if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 { + for primaryKey := range primaryKeys { + if value, ok := c.primaryIndex[primaryKey]; ok { + c.mu.RUnlock() + return value, nil + } + } + } + c.mu.RUnlock() + + value, primaryKey, secondaryKey, err := loadFunc() + if err != nil { + var zero V + return zero, err + } + + c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value) + + return value, nil +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 731dd857d..94255bf5b 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -30,6 +30,7 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store/cache" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" @@ -55,6 +56,17 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + + // Cache for user data: primary key = userID, secondary key = PatID, tertiary key = accountID + userCache *cache.TripleKeyCache[string, string, string, *types.User] + // Cache for account settings: primary key = accountID, secondary key = accountID + settingsCache *cache.SingleKeyCache[string, *types.Settings] + // Cache for peer: primary key = peerKey, secondary key = peerID, tertiary key = accountID + peerCache *cache.TripleKeyCache[string, string, string, *nbpeer.Peer] + // Cache for accountID: primary key = peerKey, secondary key = accountID + accountIDCache *cache.DualKeyCache[string, string, string] + // Cache for domain and category: primary key = accountID + domainDataCache *cache.SingleKeyCache[string, *DomainData] } type installation struct { @@ -62,6 +74,11 @@ type installation struct { InstallationIDValue string } +type DomainData struct { + Domain string + Category string +} + type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. @@ -94,7 +111,17 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met if skipMigration { log.WithContext(ctx).Infof("skipping migration") - return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil + return &SqlStore{ + db: db, + storeEngine: storeEngine, + metrics: metrics, + installationPK: 1, + userCache: cache.NewTripleKeyCache[string, string, string, *types.User](), + settingsCache: cache.NewSingleKeyCache[string, *types.Settings](), + peerCache: cache.NewTripleKeyCache[string, string, string, *nbpeer.Peer](), + accountIDCache: cache.NewDualKeyCache[string, string, string](), + domainDataCache: cache.NewSingleKeyCache[string, *DomainData](), + }, nil } if err := migratePreAuto(ctx, db); err != nil { @@ -113,7 +140,17 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePostAuto: %w", err) } - return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil + return &SqlStore{ + db: db, + storeEngine: storeEngine, + metrics: metrics, + installationPK: 1, + userCache: cache.NewTripleKeyCache[string, string, string, *types.User](), + settingsCache: cache.NewSingleKeyCache[string, *types.Settings](), + peerCache: cache.NewTripleKeyCache[string, string, string, *nbpeer.Peer](), + accountIDCache: cache.NewDualKeyCache[string, string, string](), + domainDataCache: cache.NewSingleKeyCache[string, *DomainData](), + }, nil } func GetKeyQueryCondition(s *SqlStore) string { @@ -191,6 +228,12 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro return nil }) + s.userCache.InvalidateBySecondaryKey(ctx, account.Id) + s.settingsCache.Invalidate(ctx, account.Id) + s.accountIDCache.InvalidateByPrimaryKey(ctx, account.Id) + s.peerCache.InvalidateByTertiaryKey(ctx, account.Id) + s.domainDataCache.Invalidate(ctx, account.Id) + took := time.Since(start) if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) @@ -282,6 +325,12 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er return nil }) + s.userCache.InvalidateBySecondaryKey(ctx, account.Id) + s.settingsCache.Invalidate(ctx, account.Id) + s.accountIDCache.InvalidateByPrimaryKey(ctx, account.Id) + s.peerCache.InvalidateByTertiaryKey(ctx, account.Id) + s.domainDataCache.Invalidate(ctx, account.Id) + took := time.Since(start) if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) @@ -347,6 +396,8 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil }) + s.peerCache.InvalidateByPrimaryKey(ctx, peer.ID) + if err != nil { return err } @@ -374,6 +425,8 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID return err } + s.domainDataCache.Invalidate(ctx, accountID) + if result.RowsAffected == 0 { err = status.Errorf(status.NotFound, "account %s", accountID) return err @@ -402,6 +455,8 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, return err } + s.peerCache.InvalidateByPrimaryKey(ctx, peerID) + if result.RowsAffected == 0 { err = status.Errorf(status.NotFound, peerNotFoundFMT, peerID) return err @@ -429,6 +484,8 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW return err } + s.peerCache.InvalidateByPrimaryKey(ctx, peerWithLocation.ID) + if result.RowsAffected == 0 { err = status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID) return err @@ -452,6 +509,11 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { err = status.Errorf(status.Internal, "failed to save users to store") return err } + + for _, user := range users { + s.userCache.InvalidateByPrimaryKey(ctx, user.Id) + } + return nil } @@ -466,6 +528,9 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error { err = status.Errorf(status.Internal, "failed to save user to store") return err } + + s.userCache.InvalidateByPrimaryKey(ctx, user.Id) + return nil } @@ -629,6 +694,31 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren var err error defer s.trackStoreOperation(time.Now(), "GetUserByPATID", &err)() + // If no locking is required, try to get from cache first + if lockStrength == LockingStrengthNone && s.userCache != nil { + user, cacheErr := s.userCache.GetOrSetBySecondaryKey(ctx, patID, func() (*types.User, string, string, error) { + return s.getUserByPATIDFromDB(ctx, lockStrength, patID) + }) + if cacheErr != nil { + err = cacheErr + return nil, err + } + return user, nil + } + + // If locking is required, bypass cache and fetch directly from DB + user, _, _, dbErr := s.getUserByPATIDFromDB(ctx, lockStrength, patID) + if dbErr != nil { + err = dbErr + return nil, err + } + return user, nil +} + +func (s *SqlStore) getUserByPATIDFromDB(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, string, string, error) { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -640,21 +730,43 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren Where("personal_access_tokens.id = ?", patID).Take(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - err = status.NewPATNotFoundError(patID) - return nil, err + err := status.NewPATNotFoundError(patID) + return nil, "", "", err } log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error) - err = status.NewGetUserFromStoreError() - return nil, err + err := status.NewGetUserFromStoreError() + return nil, "", "", err } - return &user, nil + return &user, patID, user.AccountID, nil } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { var err error defer s.trackStoreOperation(time.Now(), "GetUserByUserID", &err)() + // If no locking is required, try to get from cache first + if lockStrength == LockingStrengthNone && s.userCache != nil { + user, cacheErr := s.userCache.GetOrSet(ctx, userID, func() (*types.User, string, string, error) { + return s.getUserByUserIDFromDB(ctx, lockStrength, userID) + }) + if cacheErr != nil { + err = cacheErr + return nil, err + } + return user, nil + } + + // If locking is required, bypass cache and fetch directly from DB + user, _, _, dbErr := s.getUserByUserIDFromDB(ctx, lockStrength, userID) + if dbErr != nil { + err = dbErr + return nil, err + } + return user, nil +} + +func (s *SqlStore) getUserByUserIDFromDB(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, string, string, error) { ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -667,14 +779,13 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - err = status.NewUserNotFoundError(userID) - return nil, err + return nil, "", "", status.NewUserNotFoundError(userID) } - err = status.NewGetUserFromStoreError() - return nil, err + return nil, "", "", status.NewGetUserFromStoreError() } - return &user, nil + // Return user, accountID (secondary key), and no error + return &user, "", user.AccountID, nil } func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error { @@ -695,6 +806,8 @@ func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) err return err } + s.userCache.InvalidateByPrimaryKey(ctx, userID) + return nil } @@ -1085,16 +1198,40 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) var err error defer s.trackStoreOperation(time.Now(), "GetAccountIDByPeerPubKey", &err)() + // Try to get from cache first + if s.accountIDCache != nil { + accountID, cacheErr := s.accountIDCache.GetOrSet(ctx, peerKey, func() (string, string, error) { + accountId, err := s.getAccountIDByPeerPubKeyFromDB(ctx, peerKey) + if err != nil { + return "", "", err + } + return accountId, accountId, nil + }) + if cacheErr != nil { + err = cacheErr + return "", err + } + return accountID, nil + } + + // Fallback to direct DB query if cache is not available + accountID, dbErr := s.getAccountIDByPeerPubKeyFromDB(ctx, peerKey) + if dbErr != nil { + err = dbErr + return "", err + } + return accountID, nil +} + +func (s *SqlStore) getAccountIDByPeerPubKeyFromDB(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).Take(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - err = status.Errorf(status.NotFound, "account not found: index lookup failed") - return "", err + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - err = status.NewGetAccountFromStoreError(result.Error) - return "", err + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -1265,6 +1402,28 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking var err error defer s.trackStoreOperation(time.Now(), "GetPeerByPeerPubKey", &err)() + // If no locking is required, try to get from cache first + if lockStrength == LockingStrengthNone && s.peerCache != nil { + peer, cacheErr := s.peerCache.GetOrSet(ctx, peerKey, func() (*nbpeer.Peer, string, string, error) { + return s.getPeerByPeerPubKeyFromDB(ctx, lockStrength, peerKey) + }) + if cacheErr != nil { + err = cacheErr + return nil, err + } + return peer, nil + } + + // If locking is required, bypass cache and fetch directly from DB + peer, _, _, dbErr := s.getPeerByPeerPubKeyFromDB(ctx, lockStrength, peerKey) + if dbErr != nil { + err = dbErr + return nil, err + } + return peer, nil +} + +func (s *SqlStore) getPeerByPeerPubKeyFromDB(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, string, string, error) { ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -1278,33 +1437,51 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - err = status.NewPeerNotFoundError(peerKey) - return nil, err + return nil, "", "", status.NewPeerNotFoundError(peerKey) } - err = status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) - return nil, err + return nil, "", "", status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } - return &peer, nil + return &peer, peer.ID, peer.AccountID, nil } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { var err error defer s.trackStoreOperation(time.Now(), "GetAccountSettings", &err)() + // If no locking is required, try to get from cache first + if lockStrength == LockingStrengthNone && s.settingsCache != nil { + settings, cacheErr := s.settingsCache.GetOrSet(ctx, accountID, func() (*types.Settings, error) { + return s.getAccountSettingsFromDB(ctx, lockStrength, accountID) + }) + if cacheErr != nil { + err = cacheErr + return nil, err + } + return settings, nil + } + + // If locking is required, bypass cache and fetch directly from DB + settings, dbErr := s.getAccountSettingsFromDB(ctx, lockStrength, accountID) + if dbErr != nil { + err = dbErr + return nil, err + } + return settings, nil +} + +func (s *SqlStore) getAccountSettingsFromDB(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var accountSettings types.AccountSettings - if err = tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - err = status.Errorf(status.NotFound, "settings not found") - return nil, err + return nil, status.Errorf(status.NotFound, "settings not found") } - err = status.Errorf(status.Internal, "issue getting settings from store: %s", err) - return nil, err + return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil } @@ -1358,6 +1535,8 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri return err } + s.userCache.InvalidateByPrimaryKey(ctx, userID) + return nil } @@ -1881,6 +2060,28 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength var err error defer s.trackStoreOperation(time.Now(), "GetPeerByID", &err)() + // If no locking is required, try to get from cache first + if lockStrength == LockingStrengthNone && s.peerCache != nil { + peer, cacheErr := s.peerCache.GetOrSetBySecondaryKey(ctx, peerID, func() (*nbpeer.Peer, string, string, error) { + return s.getPeerByIDFromDB(ctx, lockStrength, accountID, peerID) + }) + if cacheErr != nil { + err = cacheErr + return nil, err + } + return peer, nil + } + + // If locking is required, bypass cache and fetch directly from DB + peer, _, _, dbErr := s.getPeerByIDFromDB(ctx, lockStrength, accountID, peerID) + if dbErr != nil { + err = dbErr + return nil, err + } + return peer, nil +} + +func (s *SqlStore) getPeerByIDFromDB(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, string, string, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -1891,14 +2092,12 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength Take(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - err = status.NewPeerNotFoundError(peerID) - return nil, err + return nil, "", "", status.NewPeerNotFoundError(peerID) } - err = status.Errorf(status.Internal, "failed to get peer from store") - return nil, err + return nil, "", "", status.Errorf(status.Internal, "failed to get peer from store") } - return peer, nil + return peer, peer.Key, peer.AccountID, nil } // GetPeersByIDs retrieves peers by their IDs and account ID. @@ -2017,6 +2216,9 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri return err } + s.peerCache.InvalidateByPrimaryKey(ctx, peerID) + s.accountIDCache.InvalidateBySecondaryKey(ctx, accountID) + return nil } @@ -2061,9 +2263,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, - metrics: s.metrics, - storeEngine: s.storeEngine, + db: tx, + metrics: s.metrics, + storeEngine: s.storeEngine, + peerCache: s.peerCache, + userCache: s.userCache, + accountIDCache: s.accountIDCache, + domainDataCache: s.domainDataCache, + settingsCache: s.settingsCache, } } @@ -2124,6 +2331,28 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength var err error defer s.trackStoreOperation(time.Now(), "GetAccountDomainAndCategory", &err)() + // If no locking is required, try to get from cache first + if lockStrength == LockingStrengthNone && s.domainDataCache != nil { + domainData, cacheErr := s.domainDataCache.GetOrSet(ctx, accountID, func() (*DomainData, error) { + return s.getAccountDomainAndCategoryFromDB(ctx, lockStrength, accountID) + }) + if cacheErr != nil { + err = cacheErr + return "", "", err + } + return domainData.Domain, domainData.Category, nil + } + + // If locking is required, bypass cache and fetch directly from DB + domainData, dbErr := s.getAccountDomainAndCategoryFromDB(ctx, lockStrength, accountID) + if dbErr != nil { + err = dbErr + return "", "", err + } + return domainData.Domain, domainData.Category, nil +} + +func (s *SqlStore) getAccountDomainAndCategoryFromDB(ctx context.Context, lockStrength LockingStrength, accountID string) (*DomainData, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2134,14 +2363,15 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength Where(idQueryCondition, accountID).Take(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - err = status.Errorf(status.NotFound, "account not found") - return "", "", err + return nil, status.Errorf(status.NotFound, "account not found") } - err = status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) - return "", "", err + return nil, status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) } - return account.Domain, account.DomainCategory, nil + return &DomainData{ + Domain: account.Domain, + Category: account.DomainCategory, + }, nil } // GetGroupByID retrieves a group by ID and account ID. @@ -2822,6 +3052,8 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, se return err } + s.settingsCache.Invalidate(ctx, accountID) + if result.RowsAffected == 0 { err = status.NewAccountNotFoundError(accountID) return err @@ -3421,6 +3653,7 @@ func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) err err = status.Errorf(status.Internal, "failed to mark account as primary") return err } + s.domainDataCache.Invalidate(ctx, accountID) if result.RowsAffected == 0 { err = status.NewAccountNotFoundError(accountID)