mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
@@ -46,6 +46,9 @@ const (
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
// JWT token cache TTL for the client daemon (disabled by default)
|
||||
defaultJWTCacheTTL = 0
|
||||
|
||||
errRestoreResidualState = "failed to restore residual state: %v"
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||
@@ -81,6 +84,8 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,6 +379,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.DisableNotifications = msg.DisableNotifications
|
||||
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
config.BlockInbound = msg.BlockInbound
|
||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
|
||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
if msg.SshJWTCacheTTL != nil {
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
@@ -493,7 +510,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous oauth flow info")
|
||||
return &proto.LoginResponse{
|
||||
@@ -510,7 +527,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||
return nil, err
|
||||
@@ -1065,12 +1082,235 @@ func (s *Server) Status(
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
return &statusResponse, nil
|
||||
}
|
||||
|
||||
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
|
||||
func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetSSHServerStatus()
|
||||
sshServerState := &proto.SSHServerState{
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
|
||||
Username: session.Username,
|
||||
RemoteAddress: session.RemoteAddress,
|
||||
Command: session.Command,
|
||||
JwtUsername: session.JWTUsername,
|
||||
})
|
||||
}
|
||||
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
req *proto.GetPeerSSHHostKeyRequest,
|
||||
) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
peerAddress := req.GetPeerAddress()
|
||||
hostKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
|
||||
response := &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
}
|
||||
|
||||
if !found {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
response.SshHostKey = hostKey
|
||||
|
||||
if statusRecorder == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
response.PeerIP = peerState.IP
|
||||
response.PeerFQDN = peerState.FQDN
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
|
||||
func (s *Server) getJWTCacheTTL() time.Duration {
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil || config.SSHJWTCacheTTL == nil {
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
seconds := *config.SSHJWTCacheTTL
|
||||
if seconds == 0 {
|
||||
log.Debug("SSH JWT cache disabled (configured to 0)")
|
||||
return 0
|
||||
}
|
||||
|
||||
ttl := time.Duration(seconds) * time.Second
|
||||
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
|
||||
return ttl
|
||||
}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
func (s *Server) RequestJWTAuth(
|
||||
ctx context.Context,
|
||||
msg *proto.RequestJWTAuthRequest,
|
||||
) (*proto.RequestJWTAuthResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||
}
|
||||
|
||||
jwtCacheTTL := s.getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
if cachedToken, found := s.jwtCache.get(); found {
|
||||
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: cachedToken,
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
|
||||
if hint == "" {
|
||||
hint = profilemanager.GetLoginHint()
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.flow = oAuthFlow
|
||||
s.oauthAuthFlow.info = authInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
DeviceCode: authInfo.DeviceCode,
|
||||
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
func (s *Server) WaitJWTToken(
|
||||
ctx context.Context,
|
||||
req *proto.WaitJWTTokenRequest,
|
||||
) (*proto.WaitJWTTokenResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
oAuthFlow := s.oauthAuthFlow.flow
|
||||
authInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
|
||||
}
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
|
||||
jwtCacheTTL := s.getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
s.jwtCache.store(token, jwtCacheTTL)
|
||||
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||
} else {
|
||||
log.Debug("JWT caching disabled, not storing token")
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
s.mutex.Unlock()
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: tokenInfo.GetTokenToUse(),
|
||||
TokenType: tokenInfo.TokenType,
|
||||
ExpiresIn: int64(tokenInfo.ExpiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
@@ -1136,25 +1376,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableServerRoutes := cfg.DisableServerRoutes
|
||||
blockLANAccess := cfg.BlockLANAccess
|
||||
|
||||
enableSSHRoot := false
|
||||
if cfg.EnableSSHRoot != nil {
|
||||
enableSSHRoot = *cfg.EnableSSHRoot
|
||||
}
|
||||
|
||||
enableSSHSFTP := false
|
||||
if cfg.EnableSSHSFTP != nil {
|
||||
enableSSHSFTP = *cfg.EnableSSHSFTP
|
||||
}
|
||||
|
||||
enableSSHLocalPortForwarding := false
|
||||
if cfg.EnableSSHLocalPortForwarding != nil {
|
||||
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
|
||||
}
|
||||
|
||||
enableSSHRemotePortForwarding := false
|
||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
|
||||
}
|
||||
|
||||
disableSSHAuth := false
|
||||
if cfg.DisableSSHAuth != nil {
|
||||
disableSSHAuth = *cfg.DisableSSHAuth
|
||||
}
|
||||
|
||||
sshJWTCacheTTL := int32(0)
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||
}
|
||||
|
||||
return &proto.GetConfigResponse{
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
EnableSSHRoot: enableSSHRoot,
|
||||
EnableSSHSFTP: enableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1385,6 +1661,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user