Compare commits

..

10 Commits

Author SHA1 Message Date
Zoltán Papp
f4e2836d3a [client] clean up all started components on Start failure
Start's failure defer only called close(), which covers the wg interface,
firewall, rosenpass and port forwarding but leaves connMgr, srWatcher,
route/DNS/flow/state managers and the monitor goroutines running. A late
failure (e.g. the context-cancelled check after the signal stream) thus
leaked them.

Extract Stop's locked teardown into stopLocked (caller holds syncMsgMux,
does not wait on shutdownWg) and call it from both Stop and Start's defer.
The defer also cancels the run context first so goroutines started before
the failure unwind. Teardown order is unchanged.
2026-06-16 15:14:47 +02:00
Zoltán Papp
3190347849 [client] abort Start if context cancelled while waiting for signal stream
receiveSignalEvents blocks in WaitStreamConnected until the signal stream
connects or the context is cancelled. If Stop cancelled e.ctx while Start
was parked there, Start kept going: it started the remaining subsystems on
a cancelled context and marked a shutting-down engine as started. Return
the context error from receiveSignalEvents and propagate it from Start, so
the deferred cleanup runs and the cancellation reaches the caller.
2026-06-16 15:03:29 +02:00
Zoltán Papp
90af9dd8ae [client] fix WaitStreamConnected stale-channel race
The StreamConnected check and the wait-channel creation took the mutex
separately, so notifyStreamConnected could set the status and close/clear
connectedCh in between: the waiter then created a fresh channel nobody
would ever close and blocked forever. Also, the status read was unlocked
while notify wrote it under the mutex (a data race). Do the check and the
channel fetch in one locked section; drop the now-unused
getStreamStatusChan helper. Pre-existing bug, not introduced by this branch.
2026-06-16 14:51:17 +02:00
Zoltán Papp
5cf865b243 [client] bound WaitStreamConnected in signal client tests
The tests waited on WaitStreamConnected with context.Background() and the
client's own context was also Background, so a stream that never connects
would hang until the suite timeout. Pass a 5s timeout context and assert
StreamConnected afterwards so the tests fail fast with a clear reason.
2026-06-16 14:47:07 +02:00
Zoltán Papp
67b362b4a4 [client] interrupt connect backoff on context cancel
The run loop retried with a raw ExponentialBackOff, so a backoff sleep
ignored context cancellation. Now that ConnectClient.Stop waits for the
run loop to exit, a cancel landing during a sleep would block Stop for
the full interval (up to MaxInterval). Wrap the backoff with the run
context so Retry returns promptly on cancel; the retry budget itself
(MaxElapsedTime) is unchanged.
2026-06-16 14:42:35 +02:00
Zoltán Papp
32fccdeede [client] init context state in engine tests
Engine tests built the engine context with context.WithCancel(
context.Background()), omitting CtxInitState. Now that the run context
is created in the constructor, the wgIfaceMonitor goroutine can reach
triggerClientRestart during teardown, which calls CtxGetState and
panics on the missing state. Real entry points (up, embed, service)
always CtxInitState; only the tests skipped it.
2026-06-16 14:35:44 +02:00
Zoltán Papp
98c71d7913 [client] fix Start/Stop race by making the run loop own engine shutdown
ConnectClient.Stop stopped the engine directly while the run loop's
backoff cycle could still be starting an engine, so Engine.close raced
Engine.Start (e.g. firewall setup reading wgInterface while close nils
it). embed.Client.Start's rollback only avoided a deadlock by cancelling
before Stop; the race itself remained and was caught by -race.

Make the run loop the sole owner of engine shutdown: derive the run
context in NewConnectClient, and have Stop cancel it and wait for the
loop to exit (skipping the wait when the loop never ran) instead of
calling engine.Stop. The loop now always stops the engine on its way
out, dropping the unsynchronised wgInterface check it used to guard that
call. Self-calls from within the loop use runCancel to avoid waiting on
themselves.

embed keeps a defensive pre-Stop cancel(); the daemon's cleanupConnection
gets a TODO to adopt Stop() rather than stopping the engine in parallel.
2026-06-16 13:57:41 +02:00
Zoltán Papp
002e0b036f [client] let engine context unblock WaitStreamConnected
WaitStreamConnected only watched the signal client's own context, which
derives from the parent engineCtx rather than the engine's run context.
A Start blocked here (signal stream not yet up) could therefore not be
released by Engine.Stop, since Stop only cancels the engine's run
context.

Pass a context into WaitStreamConnected and select on it too, and have
the engine pass e.ctx, so Stop cancelling e.ctx unblocks a parked Start.
Update the Client interface, the mock, and callers accordingly.
2026-06-16 13:11:46 +02:00
Zoltán Papp
c370c72d93 [client] make Engine single-use and guard against double Start
Create the run context once in NewEngine instead of in Start. This
keeps e.cancel valid for the engine's whole lifetime, so Stop can
cancel a Start that is blocked waiting on the network while holding
syncMsgMux: Stop now cancels before taking the lock, unblocking that
Start so it can release the mutex.

Reject re-entry into Start: a non-nil wgInterface means a prior Start
already ran (ErrEngineAlreadyStarted), and a cancelled run context
means the engine was stopped (ErrEngineAlreadyStopped). Both checks run
before the cleanup defer so a duplicate call cannot tear down the
running engine's state.
2026-06-16 13:11:46 +02:00
Zoltán Papp
5895a39380 [client] always clean up on Engine.Start failure via defer
The rosenpass init paths (NewManager/Run) returned without calling
e.close(), leaking the WireGuard interface and other partially
initialized state on failure. Per-branch cleanup was easy to miss when
adding new early returns.

Convert Start to a named error return and tear down via a single defer
that calls e.close() whenever err != nil, removing the scattered
per-branch close() calls (including the redundant one in initFirewall).
2026-06-16 11:54:08 +02:00
16 changed files with 172 additions and 292 deletions

View File

@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
// ConnectClient.Stop now cancels its own run context and waits for the
// run loop to tear the engine down, so this cancel() is no longer
// required to break the deadlock and could be removed. It is kept as a
// defensive belt-and-suspenders: cancelling the parent context first
// guarantees the run loop is unblocked even if Stop's contract regresses.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())

View File

@@ -11,6 +11,7 @@ import (
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
type ConnectClient struct {
ctx context.Context
runCancel context.CancelFunc
runExited chan struct{}
runOnce sync.Once
runStarted atomic.Bool
config *profilemanager.Config
statusRecorder *peer.Status
@@ -70,8 +75,14 @@ func NewConnectClient(
config *profilemanager.Config,
statusRecorder *peer.Status,
) *ConnectClient {
// Derive the run context here so Stop owns the cancel that unblocks the run
// loop. runCancel is set once at construction, so Stop can call it without
// racing the run loop's startup. Callers therefore need not cancel before Stop.
runCtx, runCancel := context.WithCancel(ctx)
return &ConnectClient{
ctx: ctx,
ctx: runCtx,
runCancel: runCancel,
runExited: make(chan struct{}),
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
@@ -132,6 +143,11 @@ func (c *ConnectClient) RunOniOS(
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
// Mark the loop as started and signal exit on return so Stop can wait for
// the loop to finish (and skip the wait if the loop never ran).
c.runStarted.Store(true)
defer c.runOnce.Do(func() { close(c.runExited) })
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -287,7 +303,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
c.runCancel()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
}
return wrapErr(err)
@@ -407,14 +423,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine = nil
c.engineMutex.Unlock()
// todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.statusRecorder.ClientTeardown()
@@ -430,12 +442,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
c.runCancel()
}
return err
}
@@ -513,11 +525,9 @@ func (c *ConnectClient) Status() StatusType {
}
func (c *ConnectClient) Stop() error {
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
c.runCancel()
if c.runStarted.Load() {
<-c.runExited
}
return nil
}

View File

@@ -207,35 +207,3 @@ func FormatAnswers(answers []dns.RR) string {
}
return "[" + strings.Join(parts, ", ") + "]"
}
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
// RFC 6891 a responder must not include an OPT RR toward a client that did not
// advertise EDNS0.
func StripOPT(msg *dns.Msg) {
if len(msg.Extra) == 0 {
return
}
out := msg.Extra[:0]
for _, rr := range msg.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
msg.Extra = out
}
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
// the message, if present.
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
opt := msg.IsEdns0()
if opt == nil {
return nil, false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok {
return ede, true
}
}
return nil, false
}

View File

@@ -120,42 +120,3 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
StripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestExtractEDE(t *testing.T) {
t.Run("no edns", func(t *testing.T) {
_, ok := ExtractEDE(&dns.Msg{})
assert.False(t, ok, "message without OPT has no EDE")
})
t.Run("edns without ede", func(t *testing.T) {
rm := &dns.Msg{}
rm.SetEdns0(4096, false)
_, ok := ExtractEDE(rm)
assert.False(t, ok, "OPT without EDE option returns false")
})
t.Run("with ede", func(t *testing.T) {
rm := &dns.Msg{}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
rm.Extra = append(rm.Extra, opt)
ede, ok := ExtractEDE(rm)
assert.True(t, ok, "EDE option should be found")
assert.Equal(t, uint16(49152), ede.InfoCode)
assert.Equal(t, "upstream timeout", ede.ExtraText)
})
}

View File

@@ -451,7 +451,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
resutil.StripOPT(rm)
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
@@ -462,7 +462,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
}
if !hadEdns {
resutil.StripOPT(rm)
stripOPT(rm)
}
u.markUpstreamOk(upstream)
@@ -520,6 +520,22 @@ func upstreamUDPSize() uint16 {
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}

View File

@@ -913,6 +913,19 @@ func TestEDEName(t *testing.T) {
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")

View File

@@ -26,15 +26,6 @@ import (
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
// EDE info codes the forwarder emits on upstream failures so the querying
// client can see the reason without inspecting this peer's logs. They live in
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
// real upstream EDE here, so these cannot collide with a genuine code.
const (
edeNetbirdUpstreamTimeout uint16 = 49152
edeNetbirdUpstreamFailure uint16 = 49153
)
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
@@ -229,7 +220,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
}
@@ -342,7 +333,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
reqHasEdns bool,
startTime time.Time,
) {
qType := question.Qtype
@@ -384,10 +374,6 @@ func (f *DNSForwarder) handleDNSError(
logger.Warnf(errResolveFailed, domain, result.Err)
}
if reqHasEdns {
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
}
f.writeResponse(logger, w, resp, domain, startTime)
}
@@ -428,33 +414,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches
}
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
func edeCodeFor(dnsErr *net.DNSError) uint16 {
if dnsErr != nil && dnsErr.IsTimeout {
return edeNetbirdUpstreamTimeout
}
return edeNetbirdUpstreamFailure
}
// edeText builds the EDE extra-text describing the class of upstream failure.
// It deliberately omits the upstream server address, which may be an internal
// resolver and is exposed to any client permitted to use the route; the full
// detail stays in the forwarder's local log.
func edeText(dnsErr *net.DNSError) string {
if dnsErr != nil && dnsErr.IsTimeout {
return "netbird forwarder: upstream timeout"
}
return "netbird forwarder: upstream failure"
}
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
// creating the OPT pseudo-record if the response does not already carry one.
func attachEDE(resp *dns.Msg, code uint16, text string) {
opt := resp.IsEdns0()
if opt == nil {
resp.SetEdns0(dns.DefaultMsgSize, false)
opt = resp.IsEdns0()
}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
@@ -618,85 +617,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
}
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string
lookupErr error
reqEdns bool
wantEDE bool
wantCode uint16
wantTextHas string
}{
{
name: "timeout with edns0",
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamTimeout,
wantTextHas: "netbird forwarder: upstream timeout",
},
{
name: "server failure with edns0",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamFailure,
wantTextHas: "netbird forwarder: upstream failure",
},
{
name: "no edns0 in request omits ede",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: false,
wantEDE: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr(nil), tt.lookupErr).Once()
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
if tt.reqEdns {
query.SetEdns0(dns.DefaultMsgSize, false)
}
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
mockResolver.AssertExpectations(t)
require.NotNil(t, writtenResp, "expected a response")
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
ede, ok := resutil.ExtractEDE(writtenResp)
if !tt.wantEDE {
assert.False(t, ok, "response must not carry EDE")
return
}
require.True(t, ok, "response must carry EDE")
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}

View File

@@ -86,6 +86,8 @@ const (
var ErrResetConnection = fmt.Errorf("reset connection")
var ErrEngineAlreadyStarted = errors.New("engine already started")
type EngineConfig struct {
WgPort int
WgIfaceName string
@@ -199,6 +201,8 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
started bool
wgInterface WGIface
udpMux *udpmux.UniversalUDPMuxDefault
@@ -279,9 +283,15 @@ func NewEngine(
services EngineServices,
mobileDep MobileDependency,
) *Engine {
// The engine is single-use: a fresh instance is built per connection
// cycle (see Client.run), so the run context is created once here rather
// than in Start.
ctx, cancel := context.WithCancel(clientCtx)
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
ctx: ctx,
cancel: cancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
@@ -314,8 +324,34 @@ func (e *Engine) Stop() error {
log.Debugf("tried stopping engine that is nil")
return nil
}
e.cancel()
e.syncMsgMux.Lock()
e.stopLocked()
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// stopLocked tears down everything Start may have brought up, in the order
// teardown requires (DNS before the interface goes down, flow manager after).
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
// path, so a partially-initialized engine is cleaned up the same way; every
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
// after releasing the lock, since the goroutines also take syncMsgMux.
func (e *Engine) stopLocked() {
if e.connMgr != nil {
e.connMgr.Close()
}
@@ -366,10 +402,6 @@ func (e *Engine) Stop() error {
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
@@ -388,21 +420,6 @@ func (e *Engine) Stop() error {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
@@ -440,18 +457,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if err := iface.ValidateMTU(e.config.MTU); err != nil {
// The engine is single-use. Reject a duplicate start and a start on an
// already-stopped engine (run context cancelled).
if e.started {
return ErrEngineAlreadyStarted
}
if ctxErr := e.ctx.Err(); ctxErr != nil {
return fmt.Errorf("engine already stopped: %w", ctxErr)
}
e.started = true
// Tear down any partially-initialized state on a failed start. Cancel the
// run context first so goroutines started before the failure (connMgr,
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
// not just what close() covers.
defer func() {
if err != nil {
e.cancel()
e.stopLocked()
}
}()
if err = iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err)
}
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface()
@@ -485,13 +522,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
}
e.dnsServer = dnsServer
@@ -526,7 +561,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
}
@@ -535,7 +569,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -547,7 +580,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err)
}
@@ -572,9 +604,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.acl = acl.NewDefaultManager(e.firewall)
}
err = e.dnsServer.Initialize()
if err != nil {
e.close()
if err := e.dnsServer.Initialize(); err != nil {
return fmt.Errorf("initialize dns server: %w", err)
}
@@ -586,7 +616,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents()
if err = e.receiveSignalEvents(); err != nil {
return err
}
e.receiveManagementEvents()
e.receiveJobEvents()
@@ -638,7 +670,6 @@ func (e *Engine) createFirewall() error {
func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err)
}
@@ -1698,7 +1729,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
func (e *Engine) receiveSignalEvents() error {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
@@ -1762,7 +1793,12 @@ func (e *Engine) receiveSignalEvents() {
}
}()
e.signal.WaitStreamConnected()
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
e.signal.WaitStreamConnected(e.ctx)
if err := e.ctx.Err(); err != nil {
return fmt.Errorf("wait for signal stream: %w", err)
}
return nil
}
func (e *Engine) parseNATExternalIPMappings() []string {

View File

@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)

View File

@@ -251,14 +251,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
// describing why a lookup failed. The OPT is stripped from the reply when
// the original client did not request EDNS0.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(dns.DefaultMsgSize, false)
}
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
@@ -268,13 +260,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if ede, ok := resutil.ExtractEDE(reply); ok {
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
}
if !hadEdns {
resutil.StripOPT(reply)
}
resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id

View File

@@ -988,6 +988,10 @@ func (s *Server) cleanupConnection() error {
return nil
}
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
// actCancel() lets the run loop stop the engine too, so both stop it
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
// making the run loop the sole owner of engine shutdown.
if engine != nil {
if err := engine.Stop(); err != nil {
return err

View File

@@ -33,7 +33,7 @@ type Client interface {
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
Ready() bool
IsHealthy() bool
WaitStreamConnected()
WaitStreamConnected(context.Context)
SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error
SetOnReconnectedListener(func())

View File

@@ -65,7 +65,10 @@ var _ = Describe("GrpcClient", func() {
return
}
}()
clientA.WaitStreamConnected()
ctxA, cancelA := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelA()
clientA.WaitStreamConnected(ctxA)
Expect(clientA.StreamConnected()).To(BeTrue())
// connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey()
@@ -91,7 +94,10 @@ var _ = Describe("GrpcClient", func() {
}
}()
clientB.WaitStreamConnected()
ctxB, cancelB := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelB()
clientB.WaitStreamConnected(ctxB)
Expect(clientB.StreamConnected()).To(BeTrue())
// PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{
@@ -129,8 +135,10 @@ var _ = Describe("GrpcClient", func() {
return
}
}()
client.WaitStreamConnected()
Expect(client).NotTo(BeNil())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client.WaitStreamConnected(ctx)
Expect(client.StreamConnected()).To(BeTrue())
})
})

View File

@@ -213,15 +213,6 @@ func (c *GrpcClient) notifyStreamConnected() {
}
}
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
}
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
@@ -282,14 +273,24 @@ func (c *GrpcClient) IsHealthy() bool {
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *GrpcClient) WaitStreamConnected() {
func (c *GrpcClient) WaitStreamConnected(ctx context.Context) {
// Check the status and obtain the wait channel atomically: otherwise
// notifyStreamConnected could flip the status and close/clear the channel
// between the check and the channel creation, leaving us waiting forever on
// a stale channel.
c.mux.Lock()
if c.status == StreamConnected {
c.mux.Unlock()
return
}
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
ch := c.connectedCh
c.mux.Unlock()
ch := c.getStreamStatusChan()
select {
case <-ctx.Done():
case <-c.ctx.Done():
case <-ch:
}

View File

@@ -55,7 +55,7 @@ func (sm *MockClient) Ready() bool {
return sm.ReadyFunc()
}
func (sm *MockClient) WaitStreamConnected() {
func (sm *MockClient) WaitStreamConnected(context.Context) {
if sm.WaitStreamConnectedFunc == nil {
return
}