Merge branch 'main' into fix/login-cmd-root-flags

# Conflicts:
#	client/proto/daemon.pb.go
#	client/proto/daemon.proto
#	client/server/server.go
This commit is contained in:
Zoltán Papp
2025-12-30 09:51:04 +01:00
519 changed files with 52003 additions and 8035 deletions

View File

@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon (disabled by default)
defaultJWTCacheTTL = 0
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,11 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
sleepTriggeredDown atomic.Bool
jwtCache *jwtCache
}
type oauthAuthFlow struct {
@@ -100,6 +108,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
}
@@ -183,7 +192,7 @@ func (s *Server) Start() error {
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -214,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -222,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
@@ -251,7 +260,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
doInitialAutoUpdate = false
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
@@ -353,6 +363,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.CustomDNSAddress = []byte{}
}
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
if msg.DnsRouteInterval != nil {
interval := msg.DnsRouteInterval.AsDuration()
config.DNSRouteInterval = &interval
}
config.RosenpassEnabled = msg.RosenpassEnabled
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
@@ -366,6 +383,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
config.EnableSSHRoot = msg.EnableSSHRoot
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -476,13 +504,17 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting)
if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err
}
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{
@@ -499,7 +531,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err
@@ -697,7 +729,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
var doAutoUpdate bool
if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
doAutoUpdate = true
}
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
@@ -791,6 +828,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil {
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
@@ -883,6 +921,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
}
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
// todo review to update the status in case any type of error
log.Errorf("failed to cleanup connection: %v", err)
return nil, err
}
@@ -1050,20 +1089,240 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus {
if msg.ShouldRunProbes {
s.runProbes()
}
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
return &statusResponse, nil
}
func (s *Server) runProbes() {
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
func (s *Server) getSSHServerState() *proto.SSHServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetSSHServerStatus()
sshServerState := &proto.SSHServerState{
Enabled: enabled,
}
for _, session := range sessions {
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
Username: session.Username,
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
})
}
return sshServerState
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
req *proto.GetPeerSSHHostKeyRequest,
) (*proto.GetPeerSSHHostKeyResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
connectClient := s.connectClient
statusRecorder := s.statusRecorder
s.mutex.Unlock()
if connectClient == nil {
return nil, errors.New("client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
peerAddress := req.GetPeerAddress()
hostKey, found := engine.GetPeerSSHKey(peerAddress)
response := &proto.GetPeerSSHHostKeyResponse{
Found: found,
}
if !found {
return response, nil
}
response.SshHostKey = hostKey
if statusRecorder == nil {
return response, nil
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
break
}
}
return response, nil
}
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
func (s *Server) getJWTCacheTTL() time.Duration {
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil || config.SSHJWTCacheTTL == nil {
return defaultJWTCacheTTL
}
seconds := *config.SSHJWTCacheTTL
if seconds == 0 {
log.Debug("SSH JWT cache disabled (configured to 0)")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
msg *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := s.getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
return &proto.RequestJWTAuthResponse{
CachedToken: cachedToken,
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
}
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
if hint == "" {
hint = profilemanager.GetLoginHint()
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
}
s.mutex.Lock()
s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
return &proto.RequestJWTAuthResponse{
VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: authInfo.UserCode,
DeviceCode: authInfo.DeviceCode,
ExpiresIn: int64(authInfo.ExpiresIn),
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
// WaitJWTToken waits for JWT authentication completion
func (s *Server) WaitJWTToken(
ctx context.Context,
req *proto.WaitJWTTokenRequest,
) (*proto.WaitJWTTokenResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
oAuthFlow := s.oauthAuthFlow.flow
authInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
}
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
}
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := s.getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
} else {
log.Debug("JWT caching disabled, not storing token")
}
s.mutex.Lock()
s.oauthAuthFlow = oauthAuthFlow{}
s.mutex.Unlock()
return &proto.WaitJWTTokenResponse{
Token: tokenInfo.GetTokenToUse(),
TokenType: tokenInfo.TokenType,
ExpiresIn: int64(tokenInfo.ExpiresIn),
}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
@@ -1074,7 +1333,7 @@ func (s *Server) runProbes() {
}
if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() {
if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
}
@@ -1129,26 +1388,62 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
blockLANAccess := cfg.BlockLANAccess
disableFirewall := cfg.DisableFirewall
enableSSHRoot := false
if cfg.EnableSSHRoot != nil {
enableSSHRoot = *cfg.EnableSSHRoot
}
enableSSHSFTP := false
if cfg.EnableSSHSFTP != nil {
enableSSHSFTP = *cfg.EnableSSHSFTP
}
enableSSHLocalPortForwarding := false
if cfg.EnableSSHLocalPortForwarding != nil {
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
}
enableSSHRemotePortForwarding := false
if cfg.EnableSSHRemotePortForwarding != nil {
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
}
disableSSHAuth := false
if cfg.DisableSSHAuth != nil {
disableSSHAuth = *cfg.DisableSSHAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
}
return &proto.GetConfigResponse{
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
BlockLanAccess: blockLANAccess,
DisableFirewall: disableFirewall,
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
BlockLanAccess: blockLANAccess,
EnableSSHRoot: enableSSHRoot,
EnableSSHSFTP: enableSSHSFTP,
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
DisableFirewall: disableFirewall,
}, nil
}
@@ -1252,9 +1547,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
@@ -1379,6 +1674,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}