diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index 75e2dfcb6..50efefa9c 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -241,17 +241,20 @@ func (c *ConnTrack) Close() error { c.mux.Lock() defer c.mux.Unlock() - if c.started { - select { - case c.done <- struct{}{}: - default: - } + if !c.started { + return nil } + select { + case c.done <- struct{}{}: + default: + } + + c.started = false + if c.conn != nil { err := c.conn.Close() c.conn = nil - c.started = false c.RestoreAccounting() diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go index aa4a1b2b7..dd5b497ba 100644 --- a/client/internal/netflow/conntrack/conntrack_test.go +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -173,6 +173,55 @@ func TestStopRaceWithReconnectDial(t *testing.T) { ct.mux.Unlock() } +func TestCloseRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := &ConnTrack{ + dial: func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + close(dialStarted) + <-dialProceed + return second, nil + }, + done: make(chan struct{}, 1), + } + + err := ct.Start(false) + require.NoError(t, err) + + first.errChan <- assert.AnError + + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Close while dial is in progress (conn is nil). + require.NoError(t, ct.Close()) + + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Close") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + func TestStartIsIdempotent(t *testing.T) { mock := newMockListener() callCount := atomic.Int32{}