Compare commits

..

10 Commits

Author SHA1 Message Date
Zoltán Papp
f4e2836d3a [client] clean up all started components on Start failure
Start's failure defer only called close(), which covers the wg interface,
firewall, rosenpass and port forwarding but leaves connMgr, srWatcher,
route/DNS/flow/state managers and the monitor goroutines running. A late
failure (e.g. the context-cancelled check after the signal stream) thus
leaked them.

Extract Stop's locked teardown into stopLocked (caller holds syncMsgMux,
does not wait on shutdownWg) and call it from both Stop and Start's defer.
The defer also cancels the run context first so goroutines started before
the failure unwind. Teardown order is unchanged.
2026-06-16 15:14:47 +02:00
Zoltán Papp
3190347849 [client] abort Start if context cancelled while waiting for signal stream
receiveSignalEvents blocks in WaitStreamConnected until the signal stream
connects or the context is cancelled. If Stop cancelled e.ctx while Start
was parked there, Start kept going: it started the remaining subsystems on
a cancelled context and marked a shutting-down engine as started. Return
the context error from receiveSignalEvents and propagate it from Start, so
the deferred cleanup runs and the cancellation reaches the caller.
2026-06-16 15:03:29 +02:00
Zoltán Papp
90af9dd8ae [client] fix WaitStreamConnected stale-channel race
The StreamConnected check and the wait-channel creation took the mutex
separately, so notifyStreamConnected could set the status and close/clear
connectedCh in between: the waiter then created a fresh channel nobody
would ever close and blocked forever. Also, the status read was unlocked
while notify wrote it under the mutex (a data race). Do the check and the
channel fetch in one locked section; drop the now-unused
getStreamStatusChan helper. Pre-existing bug, not introduced by this branch.
2026-06-16 14:51:17 +02:00
Zoltán Papp
5cf865b243 [client] bound WaitStreamConnected in signal client tests
The tests waited on WaitStreamConnected with context.Background() and the
client's own context was also Background, so a stream that never connects
would hang until the suite timeout. Pass a 5s timeout context and assert
StreamConnected afterwards so the tests fail fast with a clear reason.
2026-06-16 14:47:07 +02:00
Zoltán Papp
67b362b4a4 [client] interrupt connect backoff on context cancel
The run loop retried with a raw ExponentialBackOff, so a backoff sleep
ignored context cancellation. Now that ConnectClient.Stop waits for the
run loop to exit, a cancel landing during a sleep would block Stop for
the full interval (up to MaxInterval). Wrap the backoff with the run
context so Retry returns promptly on cancel; the retry budget itself
(MaxElapsedTime) is unchanged.
2026-06-16 14:42:35 +02:00
Zoltán Papp
32fccdeede [client] init context state in engine tests
Engine tests built the engine context with context.WithCancel(
context.Background()), omitting CtxInitState. Now that the run context
is created in the constructor, the wgIfaceMonitor goroutine can reach
triggerClientRestart during teardown, which calls CtxGetState and
panics on the missing state. Real entry points (up, embed, service)
always CtxInitState; only the tests skipped it.
2026-06-16 14:35:44 +02:00
Zoltán Papp
98c71d7913 [client] fix Start/Stop race by making the run loop own engine shutdown
ConnectClient.Stop stopped the engine directly while the run loop's
backoff cycle could still be starting an engine, so Engine.close raced
Engine.Start (e.g. firewall setup reading wgInterface while close nils
it). embed.Client.Start's rollback only avoided a deadlock by cancelling
before Stop; the race itself remained and was caught by -race.

Make the run loop the sole owner of engine shutdown: derive the run
context in NewConnectClient, and have Stop cancel it and wait for the
loop to exit (skipping the wait when the loop never ran) instead of
calling engine.Stop. The loop now always stops the engine on its way
out, dropping the unsynchronised wgInterface check it used to guard that
call. Self-calls from within the loop use runCancel to avoid waiting on
themselves.

embed keeps a defensive pre-Stop cancel(); the daemon's cleanupConnection
gets a TODO to adopt Stop() rather than stopping the engine in parallel.
2026-06-16 13:57:41 +02:00
Zoltán Papp
002e0b036f [client] let engine context unblock WaitStreamConnected
WaitStreamConnected only watched the signal client's own context, which
derives from the parent engineCtx rather than the engine's run context.
A Start blocked here (signal stream not yet up) could therefore not be
released by Engine.Stop, since Stop only cancels the engine's run
context.

Pass a context into WaitStreamConnected and select on it too, and have
the engine pass e.ctx, so Stop cancelling e.ctx unblocks a parked Start.
Update the Client interface, the mock, and callers accordingly.
2026-06-16 13:11:46 +02:00
Zoltán Papp
c370c72d93 [client] make Engine single-use and guard against double Start
Create the run context once in NewEngine instead of in Start. This
keeps e.cancel valid for the engine's whole lifetime, so Stop can
cancel a Start that is blocked waiting on the network while holding
syncMsgMux: Stop now cancels before taking the lock, unblocking that
Start so it can release the mutex.

Reject re-entry into Start: a non-nil wgInterface means a prior Start
already ran (ErrEngineAlreadyStarted), and a cancelled run context
means the engine was stopped (ErrEngineAlreadyStopped). Both checks run
before the cleanup defer so a duplicate call cannot tear down the
running engine's state.
2026-06-16 13:11:46 +02:00
Zoltán Papp
5895a39380 [client] always clean up on Engine.Start failure via defer
The rosenpass init paths (NewManager/Run) returned without calling
e.close(), leaking the WireGuard interface and other partially
initialized state on failure. Per-branch cleanup was easy to miss when
adding new early returns.

Convert Start to a named error return and tear down via a single defer
that calls e.close() whenever err != nil, removing the scattered
per-branch close() calls (including the redundant one in initFirewall).
2026-06-16 11:54:08 +02:00
26 changed files with 317 additions and 1325 deletions

View File

@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
// ConnectClient.Stop now cancels its own run context and waits for the
// run loop to tear the engine down, so this cancel() is no longer
// required to break the deadlock and could be removed. It is kept as a
// defensive belt-and-suspenders: cancelling the parent context first
// guarantees the run loop is unblocked even if Stop's contract regresses.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())

View File

@@ -11,6 +11,7 @@ import (
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
type ConnectClient struct {
ctx context.Context
runCancel context.CancelFunc
runExited chan struct{}
runOnce sync.Once
runStarted atomic.Bool
config *profilemanager.Config
statusRecorder *peer.Status
@@ -70,8 +75,14 @@ func NewConnectClient(
config *profilemanager.Config,
statusRecorder *peer.Status,
) *ConnectClient {
// Derive the run context here so Stop owns the cancel that unblocks the run
// loop. runCancel is set once at construction, so Stop can call it without
// racing the run loop's startup. Callers therefore need not cancel before Stop.
runCtx, runCancel := context.WithCancel(ctx)
return &ConnectClient{
ctx: ctx,
ctx: runCtx,
runCancel: runCancel,
runExited: make(chan struct{}),
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
@@ -118,8 +129,6 @@ func (c *ConnectClient) RunOniOS(
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
stateFilePath string,
cacheDir string,
logFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
@@ -129,12 +138,16 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, logFilePath)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
// Mark the loop as started and signal exit on return so Stop can wait for
// the loop to finish (and skip the wait if the loop never ran).
c.runStarted.Store(true)
defer c.runOnce.Do(func() { close(c.runExited) })
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -290,7 +303,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
c.runCancel()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
}
return wrapErr(err)
@@ -410,14 +423,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine = nil
c.engineMutex.Unlock()
// todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.statusRecorder.ClientTeardown()
@@ -433,12 +442,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
c.runCancel()
}
return err
}
@@ -516,11 +525,9 @@ func (c *ConnectClient) Status() StatusType {
}
func (c *ConnectClient) Stop() error {
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
c.runCancel()
if c.runStarted.Load() {
<-c.runExited
}
return nil
}

View File

@@ -250,7 +250,6 @@ type BundleGenerator struct {
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
statePath string
cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
@@ -277,7 +276,6 @@ type GeneratorDependencies struct {
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
CPUProfile []byte
CapturePath string
RefreshStatus func()
@@ -301,7 +299,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
statePath: deps.StatePath,
cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
@@ -853,11 +850,8 @@ func (g *BundleGenerator) maskSecrets() {
}
func (g *BundleGenerator) addStateFile() error {
path := g.statePath
if path == "" {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if path == "" {
return nil
}

View File

@@ -1,36 +0,0 @@
//go:build ios
package debug
import (
"path/filepath"
log "github.com/sirupsen/logrus"
)
// swiftLogFile is the Swift app log written by the iOS app into the same log
// directory as the Go client log, so it can be collected into the bundle.
const swiftLogFile = "swift-log.log"
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
// client log (logPath) with rotation, the stderr/stdout companions and
// anonymization. The iOS app writes its own Swift log into the same directory,
// so we add it alongside the Go log.
func (g *BundleGenerator) addPlatformLog() error {
if err := g.addLogfile(); err != nil {
return err
}
if g.logPath == "" {
return nil
}
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
// The Swift log is best-effort: the app may not have written it yet.
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !android && !ios
//go:build !android
package debug

View File

@@ -8,7 +8,6 @@ import (
"errors"
"net"
"net/netip"
"slices"
"strings"
"github.com/miekg/dns"
@@ -168,10 +167,7 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA:
alternativeNetwork = "ip6"
default:
// Non-address types reach LookupIP only unexpectedly; without an
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -188,230 +184,6 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess
}
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {

View File

@@ -5,7 +5,6 @@ import (
"errors"
"net"
"net/netip"
"strings"
"testing"
"github.com/miekg/dns"
@@ -121,161 +120,3 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}

View File

@@ -28,12 +28,6 @@ const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
type firewaller interface {
@@ -207,6 +201,12 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
@@ -218,40 +218,6 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if query.IsEdns0() != nil {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
@@ -265,25 +231,6 @@ func (f *DNSForwarder) handleAddressQuery(
f.writeResponse(logger, w, resp, qname, startTime)
}
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
@@ -467,14 +414,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches
}
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
// creating the OPT pseudo-record if the response does not already carry one.
func attachEDE(resp *dns.Msg, code uint16, text string) {
opt := resp.IsEdns0()
if opt == nil {
resp.SetEdns0(dns.DefaultMsgSize, false)
opt = resp.IsEdns0()
}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
}

View File

@@ -132,41 +132,6 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1)
}
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
@@ -579,15 +544,12 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
expectEDE bool
description string
}{
{
@@ -599,13 +561,28 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NODATA",
queryType: dns.TypeCAA,
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeSuccess,
expectEDE: true,
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
@@ -621,7 +598,6 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response
var writtenResp *dns.Msg
@@ -637,213 +613,10 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}

View File

@@ -86,6 +86,8 @@ const (
var ErrResetConnection = fmt.Errorf("reset connection")
var ErrEngineAlreadyStarted = errors.New("engine already started")
type EngineConfig struct {
WgPort int
WgIfaceName string
@@ -199,6 +201,8 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
started bool
wgInterface WGIface
udpMux *udpmux.UniversalUDPMuxDefault
@@ -279,9 +283,15 @@ func NewEngine(
services EngineServices,
mobileDep MobileDependency,
) *Engine {
// The engine is single-use: a fresh instance is built per connection
// cycle (see Client.run), so the run context is created once here rather
// than in Start.
ctx, cancel := context.WithCancel(clientCtx)
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
ctx: ctx,
cancel: cancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
@@ -314,8 +324,34 @@ func (e *Engine) Stop() error {
log.Debugf("tried stopping engine that is nil")
return nil
}
e.cancel()
e.syncMsgMux.Lock()
e.stopLocked()
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// stopLocked tears down everything Start may have brought up, in the order
// teardown requires (DNS before the interface goes down, flow manager after).
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
// path, so a partially-initialized engine is cleaned up the same way; every
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
// after releasing the lock, since the goroutines also take syncMsgMux.
func (e *Engine) stopLocked() {
if e.connMgr != nil {
e.connMgr.Close()
}
@@ -366,10 +402,6 @@ func (e *Engine) Stop() error {
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
@@ -388,21 +420,6 @@ func (e *Engine) Stop() error {
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
@@ -440,18 +457,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if err := iface.ValidateMTU(e.config.MTU); err != nil {
// The engine is single-use. Reject a duplicate start and a start on an
// already-stopped engine (run context cancelled).
if e.started {
return ErrEngineAlreadyStarted
}
if ctxErr := e.ctx.Err(); ctxErr != nil {
return fmt.Errorf("engine already stopped: %w", ctxErr)
}
e.started = true
// Tear down any partially-initialized state on a failed start. Cancel the
// run context first so goroutines started before the failure (connMgr,
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
// not just what close() covers.
defer func() {
if err != nil {
e.cancel()
e.stopLocked()
}
}()
if err = iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err)
}
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface()
@@ -485,13 +522,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err)
}
e.dnsServer = dnsServer
@@ -526,7 +561,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
}
@@ -535,7 +569,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -547,7 +580,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err)
}
@@ -572,9 +604,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.acl = acl.NewDefaultManager(e.firewall)
}
err = e.dnsServer.Initialize()
if err != nil {
e.close()
if err := e.dnsServer.Initialize(); err != nil {
return fmt.Errorf("initialize dns server: %w", err)
}
@@ -586,7 +616,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents()
if err = e.receiveSignalEvents(); err != nil {
return err
}
e.receiveManagementEvents()
e.receiveJobEvents()
@@ -638,7 +670,6 @@ func (e *Engine) createFirewall() error {
func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err)
}
@@ -1698,7 +1729,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
func (e *Engine) receiveSignalEvents() error {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
@@ -1762,7 +1793,12 @@ func (e *Engine) receiveSignalEvents() {
}
}()
e.signal.WaitStreamConnected()
// todo: consider to remove this blocker. I do not see benefit to block the Start operations
e.signal.WaitStreamConnected(e.ctx)
if err := e.ctx.Err(); err != nil {
return fmt.Errorf("wait for signal stream: %w", err)
}
return nil
}
func (e *Engine) parseNATExternalIPMappings() []string {

View File

@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)

View File

@@ -226,11 +226,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// All query types for an intercepted domain are forwarded to the peer's
// DNS forwarder, which owns the name. Falling through to the system
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
// for, poisoning the whole name (including the A/AAAA records the route
// does serve). The forwarder answers NODATA for types it cannot resolve.
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
@@ -277,6 +278,19 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {

View File

@@ -9,7 +9,6 @@ import (
"net/url"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -701,8 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
@@ -719,24 +716,6 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
m.logExitNodeUpdate(exitNodeInfo)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
// see consistent state, including pairs loaded from persisted selector state.
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
for haID, routes := range clientRoutes {
routesByNetID[haID.NetID()] = routes
}
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
m.routeSelector.SyncPairedSelection(baseID, v6ID)
}
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID

View File

@@ -1,47 +0,0 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
// v6 pair so FilterSelectedExitNodes drops it.
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
const (
v4ID = route.NetID("Exit Node (raspberrypi)")
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
)
all := []route.NetID{v4ID, v6ID}
rs := routeselector.NewRouteSelector()
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
m := &DefaultManager{routeSelector: rs}
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
clientRoutes := route.HAMap{
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
}
m.updateRouteSelectorFromManagement(clientRoutes)
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
filtered := rs.FilterSelectedExitNodes(clientRoutes)
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -131,33 +132,6 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
return rs.isSelectedLocked(routeID)
}
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
// toggle, so the v6 entry carries no independent selection of its own.
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
_, baseSelected := rs.selectedRoutes[baseID]
_, baseDeselected := rs.deselectedRoutes[baseID]
delete(rs.selectedRoutes, pairedID)
delete(rs.deselectedRoutes, pairedID)
switch {
case baseSelected:
rs.selectedRoutes[pairedID] = struct{}{}
case baseDeselected:
rs.deselectedRoutes[pairedID] = struct{}{}
}
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
@@ -177,13 +151,14 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.hasUserSelectionForRouteLocked(routeID)
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -212,6 +187,83 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
@@ -265,59 +317,3 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
return nil
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelectionForRouteLocked(netID) {
if rs.isSelectedLocked(netID) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}

View File

@@ -330,73 +330,39 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
// is literal and never infers a v6 entry's state from its v4 base; callers that know
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
// to follow the base, treating the pair as a single toggle.
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// The selector does not pair-resolve: the v6 entry is independent until synced.
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// A route literally named "exit1-something" must never pair-resolve either.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1"))
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
})
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.True(t, rs.IsSelected("exit1"))
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
})
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
"v6 explicit state is cleared so it follows management like its base")
})
// Regression for the observed bug (see netbird-engine.log): persisted state has
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Prior state: both explicitly selected, then only the v4 base deselected,
// leaving the v6 entry as a stale explicit selection.
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -404,14 +370,23 @@ func TestRouteSelector_V6ExitPairSync(t *testing.T) {
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -424,15 +399,6 @@ func TestRouteSelector_V6ExitPairSync(t *testing.T) {
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
rs.DeselectAllRoutes()
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))

View File

@@ -17,7 +17,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -26,7 +25,6 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -56,7 +54,6 @@ type selectRoute struct {
Network netip.Prefix
Domains domain.List
Selected bool
Status string
extraNetworks []netip.Prefix
}
@@ -68,8 +65,6 @@ func init() {
type Client struct {
cfgFile string
stateFile string
cacheDir string
logFilePath string
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
@@ -80,21 +75,16 @@ type Client struct {
onHostDnsFn func([]string)
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
}
// NewClient instantiate a new Client
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{
cfgFile: cfgFile,
stateFile: stateFile,
cacheDir: cacheDir,
logFilePath: logFilePath,
deviceName: deviceName,
osName: osName,
osVersion: osVersion,
@@ -171,13 +161,8 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, connectClient)
// Persist the latest sync response so DebugBundle can include the network
// map. On iOS this is backed by disk to keep it out of the constrained
// process memory (see the syncstore package).
connectClient.SetSyncResponsePersistence(true)
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources
@@ -189,84 +174,6 @@ func (c *Client) Stop() {
}
c.ctxCancel()
c.setState(nil, nil)
}
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
// It works with or without a running engine: when the engine is up it reuses
// the live config, sync response and client metrics; otherwise it loads the
// config from disk (or the preloaded tvOS config).
func (c *Client) DebugBundle(anonymize bool) (string, error) {
cfg, cc := c.stateSnapshot()
// If the engine hasn't been started, load config so we can reach management.
if cfg == nil {
if c.preloadedConfig != nil {
cfg = c.preloadedConfig
} else {
var err error
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
// (temp file + rename) blocked by the tvOS sandbox.
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
}
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: c.cacheDir,
StatePath: c.stateFile,
LogPath: c.logFilePath,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
@@ -320,16 +227,6 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
// IsLoginRequiredCached reports whether the LAST observed management error was an
// auth failure (PermissionDenied/InvalidArgument), using the in-memory status
// recorder. Unlike IsLoginRequired() it performs NO network call, so it is safe to
// call from the connection listener during teardown (e.g. onDisconnected) without
// blocking on a slow or unavailable network. Returns false while connected to
// management or when the last error was not auth-related.
func (c *Client) IsLoginRequiredCached() bool {
return c.recorder.IsLoginRequired()
}
func (c *Client) IsLoginRequired() bool {
var ctx context.Context
//nolint
@@ -457,12 +354,11 @@ func (c *Client) ClearLoginComplete() {
}
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
if c.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := connectClient.Engine()
engine := c.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
@@ -481,57 +377,9 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
// Compute each route's connection status in the core (mirroring the Android
// bridge), so the UI doesn't have to infer it by string-matching the joined
// Network value against peer routes. For a merged exit node the status reflects
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
connectedRoutes := c.connectedRouteSet()
for _, r := range routes {
r.Status = routeStatus(r, connectedRoutes)
}
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
// connectedRouteSet returns the set of route keys (as strings) currently served by a
// connected peer, gathered across all connected peers' route tables. The keys match
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
// and the domain pattern for dynamic routes (e.g. "*.example.com").
func (c *Client) connectedRouteSet() map[string]struct{} {
connected := map[string]struct{}{}
for _, p := range c.recorder.GetFullStatus().Peers {
if p.ConnStatus != peer.StatusConnected {
continue
}
for r := range p.GetRoutes() {
connected[r] = struct{}{}
}
}
return connected
}
// routeStatus reports "Connected" if any of the route's keys is served by a connected
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
// domain pattern for a dynamic DNS route. Otherwise "Idle".
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
keys := make([]string, 0, 1+len(r.extraNetworks))
if len(r.Domains) > 0 {
keys = append(keys, r.Domains.SafeString())
} else {
keys = append(keys, r.Network.String())
}
for _, extra := range r.extraNetworks {
keys = append(keys, extra.String())
}
for _, k := range keys {
if _, ok := connectedRoutes[k]; ok {
return peer.StatusConnected.String()
}
}
return peer.StatusIdle.String()
}
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
var routes []*selectRoute
for id, rt := range routesMap {
@@ -614,7 +462,6 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
Status: r.Status,
})
}
@@ -623,12 +470,11 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
}
func (c *Client) SelectRoute(id string) error {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
if c.connectClient == nil {
return fmt.Errorf("not connected")
}
engine := connectClient.Engine()
engine := c.connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -654,11 +500,10 @@ func (c *Client) SelectRoute(id string) error {
}
func (c *Client) DeselectRoute(id string) error {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
if c.connectClient == nil {
return fmt.Errorf("not connected")
}
engine := connectClient.Engine()
engine := c.connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -682,22 +527,6 @@ func (c *Client) DeselectRoute(id string) error {
return nil
}
// setState stores the running engine state so DebugBundle can reuse the live
// config and ConnectClient. It is cleared on Stop.
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.connectClient = cc
}
// stateSnapshot returns the current config and ConnectClient under the lock.
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.connectClient
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")

View File

@@ -20,7 +20,6 @@ type RoutesSelectionInfo struct {
Network string
Domains *DomainDetails
Selected bool
Status string
}
type DomainCollection interface {

View File

@@ -988,6 +988,10 @@ func (s *Server) cleanupConnection() error {
return nil
}
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
// actCancel() lets the run loop stop the engine too, so both stop it
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
// making the run loop the sole owner of engine shutdown.
if engine != nil {
if err := engine.Stop(); err != nil {
return err

View File

@@ -918,10 +918,6 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
}
for _, svc := range services {
if err = transaction.DeleteServiceTargets(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
@@ -1274,10 +1270,6 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
@@ -1327,10 +1319,6 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
return nil
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}

View File

@@ -458,9 +458,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -563,9 +560,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -610,9 +604,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -1201,67 +1192,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestDeleteExpiredPeerService_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before reaping")
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
err = mgr.deleteExpiredPeerService(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when an expired peer-exposed service is reaped")
}
func TestDeleteServiceFromPeer_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before stopping")
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "stopped peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when a peer stops its exposed service")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string

View File

@@ -1989,7 +1989,7 @@ func (s *Server) addUDPRelay(ctx context.Context, mapping *proto.ProxyMapping, t
"service_id": svcID,
})
relay := udprelay.New(s.portRouterContext(ctx), udprelay.RelayConfig{
relay := udprelay.New(ctx, udprelay.RelayConfig{
Logger: entry,
Listener: listener,
Target: targetAddress,

View File

@@ -33,7 +33,7 @@ type Client interface {
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
Ready() bool
IsHealthy() bool
WaitStreamConnected()
WaitStreamConnected(context.Context)
SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error
SetOnReconnectedListener(func())

View File

@@ -65,7 +65,10 @@ var _ = Describe("GrpcClient", func() {
return
}
}()
clientA.WaitStreamConnected()
ctxA, cancelA := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelA()
clientA.WaitStreamConnected(ctxA)
Expect(clientA.StreamConnected()).To(BeTrue())
// connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey()
@@ -91,7 +94,10 @@ var _ = Describe("GrpcClient", func() {
}
}()
clientB.WaitStreamConnected()
ctxB, cancelB := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelB()
clientB.WaitStreamConnected(ctxB)
Expect(clientB.StreamConnected()).To(BeTrue())
// PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{
@@ -129,8 +135,10 @@ var _ = Describe("GrpcClient", func() {
return
}
}()
client.WaitStreamConnected()
Expect(client).NotTo(BeNil())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client.WaitStreamConnected(ctx)
Expect(client.StreamConnected()).To(BeTrue())
})
})

View File

@@ -213,15 +213,6 @@ func (c *GrpcClient) notifyStreamConnected() {
}
}
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
}
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
@@ -282,14 +273,24 @@ func (c *GrpcClient) IsHealthy() bool {
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *GrpcClient) WaitStreamConnected() {
func (c *GrpcClient) WaitStreamConnected(ctx context.Context) {
// Check the status and obtain the wait channel atomically: otherwise
// notifyStreamConnected could flip the status and close/clear the channel
// between the check and the channel creation, leaving us waiting forever on
// a stale channel.
c.mux.Lock()
if c.status == StreamConnected {
c.mux.Unlock()
return
}
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
ch := c.connectedCh
c.mux.Unlock()
ch := c.getStreamStatusChan()
select {
case <-ctx.Done():
case <-c.ctx.Done():
case <-ch:
}

View File

@@ -55,7 +55,7 @@ func (sm *MockClient) Ready() bool {
return sm.ReadyFunc()
}
func (sm *MockClient) WaitStreamConnected() {
func (sm *MockClient) WaitStreamConnected(context.Context) {
if sm.WaitStreamConnectedFunc == nil {
return
}