[client] Fix state manager race conditions (#2890)

This commit is contained in:
Viktor Liu
2024-11-15 20:05:26 +01:00
committed by GitHub
parent a1c5287b7c
commit 121dfda915
7 changed files with 118 additions and 100 deletions

View File

@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key]
return ref, ok
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
ref, err := rm.Increment(key, in)
ref, err := rm.increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.decrement(key)
}
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key]
if !ok {
logCallerF("No reference found for key %v", key)
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
var merr *multierror.Error
for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil {
if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
var merr *multierror.Error
for key := range rm.refCountMap {
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
clear(rm.refCountMap)
clear(rm.idMap)
@@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.mu.Lock()
defer rm.mu.Unlock()
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`