mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state transitions. The mutex now covers the full duration of each handler including the status check, the Up/Down call, and the flag update. Note: if Up or Down commands are triggered in parallel with sleep/wake events, the overall ordering of up/down/sleep/wake operations is still not guaranteed beyond what the mutex provides within the handler itself.
This commit is contained in:
80
client/internal/sleep/handler/handler.go
Normal file
80
client/internal/sleep/handler/handler.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type Agent interface {
|
||||
Up(ctx context.Context) error
|
||||
Down(ctx context.Context) error
|
||||
Status() (internal.StatusType, error)
|
||||
}
|
||||
|
||||
type SleepHandler struct {
|
||||
agent Agent
|
||||
|
||||
mu sync.Mutex
|
||||
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
|
||||
sleepTriggeredDown bool
|
||||
}
|
||||
|
||||
func New(agent Agent) *SleepHandler {
|
||||
return &SleepHandler{
|
||||
agent: agent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.sleepTriggeredDown {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown = false
|
||||
|
||||
log.Info("running up after wake up")
|
||||
err := s.agent.Up(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, err := s.agent.Status()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
if err = s.agent.Down(ctx); err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown = true
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return nil
|
||||
}
|
||||
153
client/internal/sleep/handler/handler_test.go
Normal file
153
client/internal/sleep/handler/handler_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockAgent struct {
|
||||
upErr error
|
||||
downErr error
|
||||
statusErr error
|
||||
status internal.StatusType
|
||||
upCalls int
|
||||
}
|
||||
|
||||
func (m *mockAgent) Up(_ context.Context) error {
|
||||
m.upCalls++
|
||||
return m.upErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Down(_ context.Context) error {
|
||||
return m.downErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Status() (internal.StatusType, error) {
|
||||
return m.status, m.statusErr
|
||||
}
|
||||
|
||||
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
|
||||
agent := &mockAgent{status: status}
|
||||
return New(agent), agent
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
h, _ := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
agent.upErr = errors.New("up failed")
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.upErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
|
||||
agent := &mockAgent{statusErr: errors.New("status error")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.statusErr)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
|
||||
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.downErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
|
||||
}
|
||||
Reference in New Issue
Block a user