[client] The status cmd will not be blocked by the ICE probe (#4597)

The status cmd will not be blocked by the ICE probe

Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state.
This commit is contained in:
Zoltan Papp
2025-10-28 16:11:35 +01:00
committed by GitHub
parent 404cab90ba
commit d7321c130b
6 changed files with 216 additions and 45 deletions

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
@@ -303,7 +303,7 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx)
resp, err := getStatus(ctx, false)
if err != nil {
return err
}
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}

View File

@@ -202,6 +202,8 @@ type Engine struct {
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
}
// Peer is an instance of the Connection Peer
@@ -244,6 +246,7 @@ func NewEngine(
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
@@ -1663,7 +1666,7 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool {
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy()
@@ -1695,8 +1698,12 @@ func (e *Engine) RunHealthProbes() bool {
}
e.syncMsgMux.Unlock()
results := e.probeICE(stuns, turns)
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
@@ -1713,13 +1720,6 @@ func (e *Engine) RunHealthProbes() bool {
return allHealthy
}
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
e.syncMsgMux.Lock()

View File

@@ -2,6 +2,8 @@ package relay
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"net"
"sync"
@@ -15,6 +17,15 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
DefaultCacheTTL = 20 * time.Second
probeTimeout = 6 * time.Second
)
var (
ErrCheckInProgress = errors.New("probe check is already in progress")
)
// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
URI string
@@ -22,8 +33,164 @@ type ProbeResult struct {
Addr string
}
type StunTurnProbe struct {
cacheResults []ProbeResult
cacheTimestamp time.Time
cacheKey string
cacheTTL time.Duration
probeInProgress bool
probeDone chan struct{}
mu sync.Mutex
}
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
return &StunTurnProbe{
cacheTTL: cacheTTL,
}
}
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if p.probeInProgress {
doneChan := p.probeDone
p.mu.Unlock()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-doneChan:
return p.getCachedResults(cacheKey, stuns, turns)
}
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
p.mu.Unlock()
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
return p.getCachedResults(cacheKey, stuns, turns)
}
// ProbeAll probes all given servers asynchronously and returns the results
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if results := p.checkCache(cacheKey); results != nil {
p.mu.Unlock()
return results
}
if p.probeInProgress {
p.mu.Unlock()
return createErrorResults(stuns, turns)
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
log.Infof("started new probe for STUN, TURN servers")
go func() {
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
}()
p.mu.Unlock()
timer := time.NewTimer(1300 * time.Millisecond)
defer timer.Stop()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-probeDone:
// when the probe is return fast, return the results right away
return p.getCachedResults(cacheKey, stuns, turns)
case <-timer.C:
// if the probe takes longer than 1.3s, return error results to avoid blocking
return createErrorResults(stuns, turns)
}
}
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
age := time.Since(p.cacheTimestamp)
if age < p.cacheTTL {
results := append([]ProbeResult(nil), p.cacheResults...)
log.Debugf("returning cached probe results (age: %v)", age)
return results
}
}
return nil
}
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
p.mu.Lock()
defer p.mu.Unlock()
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
return append([]ProbeResult(nil), p.cacheResults...)
}
return createErrorResults(stuns, turns)
}
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
defer func() {
p.mu.Lock()
p.probeInProgress = false
p.mu.Unlock()
}()
results := make([]ProbeResult, len(stuns)+len(turns))
var wg sync.WaitGroup
for i, uri := range stuns {
wg.Add(1)
go func(idx int, stunURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = stunURI.String()
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
}(i, uri)
}
stunOffset := len(stuns)
for i, uri := range turns {
wg.Add(1)
go func(idx int, turnURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = turnURI.String()
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
}(stunOffset+i, uri)
}
wg.Wait()
p.mu.Lock()
p.cacheResults = results
p.cacheTimestamp = time.Now()
p.cacheKey = cacheKey
p.mu.Unlock()
log.Debug("Stored new probe results in cache")
}
// ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}
// ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
return relayConn.LocalAddr().String(), nil
}
// ProbeAll probes all given servers asynchronously and returns the results
func ProbeAll(
ctx context.Context,
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))
func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
total := len(stuns) + len(turns)
results := make([]ProbeResult, total)
var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
defer cancel()
wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
allURIs := append(append([]*stun.URI{}, stuns...), turns...)
for i, uri := range allURIs {
results[i] = ProbeResult{
URI: uri.String(),
Err: ErrCheckInProgress,
}
}
wg.Wait()
return results
}
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
h := sha256.New()
for _, uri := range stuns {
h.Write([]byte(uri.String()))
}
for _, uri := range turns {
h.Write([]byte(uri.String()))
}
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1057,10 +1057,7 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus {
if msg.ShouldRunProbes {
s.runProbes()
}
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1070,7 +1067,7 @@ func (s *Server) Status(
return &statusResponse, nil
}
func (s *Server) runProbes() {
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
@@ -1081,7 +1078,7 @@ func (s *Server) runProbes() {
}
if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() {
if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version"
@@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
if relay.Error == probeRelay.ErrCheckInProgress.Error() {
available = "Checking..."
} else {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {