[client] Check the client status in the earlier phase (#4509)

This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. The key improvements focus on:

  • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling
  • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability
  • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops

  Key Changes

  Core Status Improvements:
  - Added waitForReady optional field to StatusRequest proto (daemon.proto:190)
  - Enhanced status checking logic to detect client state changes earlier in the connection process
  - Improved handling of client permanent exit scenarios from retry loops

  Connection & Context Management:
  - Fixed context cancellation in management and signal client retry mechanisms
  - Added proper context propagation for Login operations
  - Enhanced gRPC connection handling with better timeout management

  Error Handling & Cleanup:
  - Moved feedback channels to upper layers for better separation of concerns
  - Improved error handling patterns throughout the client server implementation
  - Fixed synchronization issues and removed debug logging
This commit is contained in:
Zoltan Papp
2025-09-20 22:14:01 +02:00
committed by GitHub
parent e254b4cde5
commit 998fb30e1e
10 changed files with 128 additions and 54 deletions

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server. // DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3) ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() defer cancel()
return grpc.DialContext( return grpc.DialContext(

View File

@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{}) status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil { if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err) return fmt.Errorf("unable to get daemon status: %v", err)
} }

View File

@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan struct{}, 1) run := make(chan struct{})
clientErr := make(chan error, 1) clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {

View File

@@ -58,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { if tlsEnabled {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
@@ -72,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
})) }))
} }
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.36.6 // protoc-gen-go v1.36.6
// protoc v5.29.3 // protoc v6.32.1
// source: daemon.proto // source: daemon.proto
package proto package proto
@@ -794,8 +794,10 @@ type StatusRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
unknownFields protoimpl.UnknownFields // the UI do not using this yet, but CLIs could use it to wait until the status is ready
sizeCache protoimpl.SizeCache WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
} }
func (x *StatusRequest) Reset() { func (x *StatusRequest) Reset() {
@@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
return false return false
} }
func (x *StatusRequest) GetWaitForReady() bool {
if x != nil && x.WaitForReady != nil {
return *x.WaitForReady
}
return false
}
type StatusResponse struct { type StatusResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
// status of the server. // status of the server.
@@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_username\"\f\n" + "\t_username\"\f\n" +
"\n" + "\n" +
"UpResponse\"g\n" + "UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" + "\rStatusRequest\x12,\n" +
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
"\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
"\r_waitForReady\"\x82\x01\n" +
"\x0eStatusResponse\x12\x16\n" + "\x0eStatusResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
"\n" + "\n" +
@@ -5231,6 +5242,7 @@ func file_daemon_proto_init() {
} }
file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
file_daemon_proto_msgTypes[5].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
file_daemon_proto_msgTypes[26].OneofWrappers = []any{ file_daemon_proto_msgTypes[26].OneofWrappers = []any{
(*PortInfo_Port)(nil), (*PortInfo_Port)(nil),
(*PortInfo_Range_)(nil), (*PortInfo_Range_)(nil),

View File

@@ -186,6 +186,8 @@ message UpResponse {}
message StatusRequest{ message StatusRequest{
bool getFullPeerStatus = 1; bool getFullPeerStatus = 1;
bool shouldRunProbes = 2; bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3;
} }
message StatusResponse{ message StatusResponse{

View File

@@ -67,6 +67,7 @@ type Server struct {
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex clientRunning bool // protected by mutex
clientRunningChan chan struct{} clientRunningChan chan struct{}
clientGiveUpChan chan struct{}
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
@@ -106,6 +107,10 @@ func (s *Server) Start() error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.clientRunning {
return nil
}
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil { if err := handlePanicLog(); err != nil {
@@ -175,12 +180,10 @@ func (s *Server) Start() error {
return nil return nil
} }
if s.clientRunning {
return nil
}
s.clientRunning = true s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1) s.clientRunningChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan) s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return nil return nil
} }
@@ -211,7 +214,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost. // mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) { func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() { defer func() {
s.mutex.Lock() s.mutex.Lock()
s.clientRunning = false s.clientRunning = false
@@ -261,6 +264,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
if err := backoff.Retry(runOperation, backOff); err != nil { if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err) log.Errorf("operation failed: %v", err)
} }
if giveUpChan != nil {
close(giveUpChan)
}
} }
// loginAttempt attempts to login using the provided information. it returns a status in case something fails // loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -379,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if s.actCancel != nil { if s.actCancel != nil {
s.actCancel() s.actCancel()
} }
ctx, cancel := context.WithCancel(s.rootCtx) ctx, cancel := context.WithCancel(callerCtx)
md, ok := metadata.FromIncomingContext(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx)
if ok { if ok {
@@ -389,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.actCancel = cancel s.actCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err) log.Warnf(errRestoreResidualState, err)
} }
state := internal.CtxGetState(ctx) state := internal.CtxGetState(s.rootCtx)
defer func() { defer func() {
status, err := state.Status() status, err := state.Status()
if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
@@ -606,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon. // Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock() s.mutex.Lock()
if s.clientRunning {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return nil, err
}
if status == internal.StatusNeedsLogin {
s.actCancel()
}
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
@@ -621,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
if err != nil { if err != nil {
return nil, err return nil, err
} }
if status != internal.StatusIdle { if status != internal.StatusIdle {
return nil, fmt.Errorf("up already in progress: current status %s", status) return nil, fmt.Errorf("up already in progress: current status %s", status)
} }
// it should be nil here, but . // it should be nil here, but in case it isn't we cancel it.
if s.actCancel != nil { if s.actCancel != nil {
s.actCancel() s.actCancel()
} }
ctx, cancel := context.WithCancel(s.rootCtx) ctx, cancel := context.WithCancel(s.rootCtx)
md, ok := metadata.FromIncomingContext(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx)
if ok { if ok {
ctx = metadata.NewOutgoingContext(ctx, md) ctx = metadata.NewOutgoingContext(ctx, md)
@@ -673,26 +694,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
// todo: handle potential race conditions
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel() defer cancel()
if !s.clientRunning { select {
s.clientRunning = true case <-s.clientGiveUpChan:
s.clientRunningChan = make(chan struct{}, 1) return nil, fmt.Errorf("client gave up to connect")
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan) case <-s.clientRunningChan:
} s.isSessionActive.Store(true)
for { return &proto.UpResponse{}, nil
select { case <-callerCtx.Done():
case <-s.clientRunningChan: log.Debug("context done, stopping the wait for engine to become ready")
s.isSessionActive.Store(true) return nil, callerCtx.Err()
return &proto.UpResponse{}, nil case <-timeoutCtx.Done():
case <-callerCtx.Done(): log.Debug("up is timed out, stopping the wait for engine to become ready")
log.Debug("context done, stopping the wait for engine to become ready") return nil, timeoutCtx.Err()
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
} }
} }
@@ -966,12 +992,46 @@ func (s *Server) Status(
ctx context.Context, ctx context.Context,
msg *proto.StatusRequest, msg *proto.StatusRequest,
) (*proto.StatusResponse, error) { ) (*proto.StatusResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() clientRunning := s.clientRunning
s.mutex.Unlock()
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
return nil, err
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-s.clientGiveUpChan:
ticker.Stop()
break loop
case <-s.clientRunningChan:
ticker.Stop()
break loop
case <-ticker.C:
status, err := state.Status()
if err != nil {
continue
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
continue
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
status, err := internal.CtxGetState(s.rootCtx).Status() status, err := internal.CtxGetState(s.rootCtx).Status()
if err != nil { if err != nil {

View File

@@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s") t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1") t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 { if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter) t.Fatalf("expected counter > 2, got %d", counter)
} }
@@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) {
profName := "default" profName := "default"
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
ic := profilemanager.ConfigInput{ ic := profilemanager.ConfigInput{
ConfigPath: filepath.Join(tempDir, profName+".json"), ConfigPath: filepath.Join(tempDir, profName+".json"),
ManagementURL: u.String(),
} }
_, err = profilemanager.UpdateOrCreateConfig(ic) _, err = profilemanager.UpdateOrCreateConfig(ic)
@@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) {
} }
s := New(ctx, "console", "", false, false) s := New(ctx, "console", "", false, false)
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
s.config = &profilemanager.Config{
ManagementURL: u,
}
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel() defer cancel()
@@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) {
Username: &currUser.Username, Username: &currUser.Username,
} }
_, err = s.Up(upCtx, upReq) _, err = s.Up(upCtx, upReq)
log.Errorf("error from Up: %v", err)
assert.Contains(t, err.Error(), "context deadline exceeded") assert.Contains(t, err.Error(), "context deadline exceeded")
} }

View File

@@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
operation := func() error { operation := func() error {
var err error var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
if err != nil { if err != nil {
log.Printf("createConnection error: %v", err) log.Printf("createConnection error: %v", err)
return err return err

View File

@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
operation := func() error { operation := func() error {
var err error var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
if err != nil { if err != nil {
log.Printf("createConnection error: %v", err) log.Printf("createConnection error: %v", err)
return err return err