From 9ba067391fc8e5a35ad20c9a5f6b03178ab60eb2 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 3 Jan 2026 09:10:02 +0100 Subject: [PATCH] [client] Fix semaphore slot leaks (#5018) - Remove WaitGroup, make SemaphoreGroup a pure semaphore - Make Add() return error instead of silently failing on context cancel - Remove context parameter from Done() to prevent slot leaks - Fix missing Done() call in conn.go error path --- client/internal/peer/conn.go | 9 +- util/semaphore-group/semaphore_group.go | 29 +---- util/semaphore-group/semaphore_group_test.go | 128 +++++++++++-------- 3 files changed, 89 insertions(+), 77 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 426c31e1a..20a2eb342 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. func (conn *Conn) Open(engineCtx context.Context) error { - conn.semaphore.Add(engineCtx) + if err := conn.semaphore.Add(engineCtx); err != nil { + return err + } conn.mu.Lock() defer conn.mu.Unlock() if conn.opened { - conn.semaphore.Done(engineCtx) + conn.semaphore.Done() return nil } @@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) if err != nil { + conn.semaphore.Done() return err } conn.workerICE = workerICE @@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { defer conn.wg.Done() conn.waitInitialRandomSleepTime(conn.ctx) - conn.semaphore.Done(conn.ctx) + conn.semaphore.Done() conn.guard.Start(conn.ctx, conn.onGuardEvent) }() diff --git a/util/semaphore-group/semaphore_group.go b/util/semaphore-group/semaphore_group.go index ad74e1bfc..462300672 100644 --- a/util/semaphore-group/semaphore_group.go +++ b/util/semaphore-group/semaphore_group.go @@ -2,12 +2,10 @@ package semaphoregroup import ( "context" - "sync" ) // SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore. type SemaphoreGroup struct { - waitGroup sync.WaitGroup semaphore chan struct{} } @@ -18,31 +16,18 @@ func NewSemaphoreGroup(limit int) *SemaphoreGroup { } } -// Add increments the internal WaitGroup counter and acquires a semaphore slot. -func (sg *SemaphoreGroup) Add(ctx context.Context) { - sg.waitGroup.Add(1) - +// Add acquire a slot +func (sg *SemaphoreGroup) Add(ctx context.Context) error { // Acquire semaphore slot select { case <-ctx.Done(): - return + return ctx.Err() case sg.semaphore <- struct{}{}: + return nil } } -// Done decrements the internal WaitGroup counter and releases a semaphore slot. -func (sg *SemaphoreGroup) Done(ctx context.Context) { - sg.waitGroup.Done() - - // Release semaphore slot - select { - case <-ctx.Done(): - return - case <-sg.semaphore: - } -} - -// Wait waits until the internal WaitGroup counter is zero. -func (sg *SemaphoreGroup) Wait() { - sg.waitGroup.Wait() +// Done releases a slot. Must be called after a successful Add. +func (sg *SemaphoreGroup) Done() { + <-sg.semaphore } diff --git a/util/semaphore-group/semaphore_group_test.go b/util/semaphore-group/semaphore_group_test.go index d4491cf77..9406da4a0 100644 --- a/util/semaphore-group/semaphore_group_test.go +++ b/util/semaphore-group/semaphore_group_test.go @@ -2,65 +2,89 @@ package semaphoregroup import ( "context" + "sync" "testing" "time" ) func TestSemaphoreGroup(t *testing.T) { - semGroup := NewSemaphoreGroup(2) - - for i := 0; i < 5; i++ { - semGroup.Add(context.Background()) - go func(id int) { - defer semGroup.Done(context.Background()) - - got := len(semGroup.semaphore) - if got == 0 { - t.Errorf("Expected semaphore length > 0 , got 0") - } - - time.Sleep(time.Millisecond) - t.Logf("Goroutine %d is running\n", id) - }(i) - } - - semGroup.Wait() - - want := 0 - got := len(semGroup.semaphore) - if got != want { - t.Errorf("Expected semaphore length %d, got %d", want, got) - } -} - -func TestSemaphoreGroupContext(t *testing.T) { semGroup := NewSemaphoreGroup(1) - semGroup.Add(context.Background()) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + _ = semGroup.Add(context.Background()) + + ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) t.Cleanup(cancel) - rChan := make(chan struct{}) - go func() { - semGroup.Add(ctx) - rChan <- struct{}{} - }() - select { - case <-rChan: - case <-time.NewTimer(2 * time.Second).C: - t.Error("Adding to semaphore group should not block when context is not done") - } - - semGroup.Done(context.Background()) - - ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second) - t.Cleanup(cancelDone) - go func() { - semGroup.Done(ctxDone) - rChan <- struct{}{} - }() - select { - case <-rChan: - case <-time.NewTimer(2 * time.Second).C: - t.Error("Releasing from semaphore group should not block when context is not done") + if err := semGroup.Add(ctxTimeout); err == nil { + t.Error("Adding to semaphore group should not block") + } +} + +func TestSemaphoreGroupFreeUp(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + _ = semGroup.Add(context.Background()) + semGroup.Done() + + ctxTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + t.Cleanup(cancel) + if err := semGroup.Add(ctxTimeout); err != nil { + t.Error(err) + } +} + +func TestSemaphoreGroupCanceledContext(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + _ = semGroup.Add(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + if err := semGroup.Add(ctx); err == nil { + t.Error("Add should return error when context is already canceled") + } +} + +func TestSemaphoreGroupCancelWhileWaiting(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + _ = semGroup.Add(context.Background()) + + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error, 1) + + go func() { + errChan <- semGroup.Add(ctx) + }() + + time.Sleep(10 * time.Millisecond) + cancel() + + if err := <-errChan; err == nil { + t.Error("Add should return error when context is canceled while waiting") + } +} + +func TestSemaphoreGroupHighConcurrency(t *testing.T) { + const limit = 10 + const numGoroutines = 100 + + semGroup := NewSemaphoreGroup(limit) + var wg sync.WaitGroup + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := semGroup.Add(context.Background()); err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + time.Sleep(time.Millisecond) + semGroup.Done() + }() + } + + wg.Wait() + + // Verify all slots were released + if got := len(semGroup.semaphore); got != 0 { + t.Errorf("Expected semaphore to be empty, got %d slots occupied", got) } }