Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon
defaultJWTCacheTTL = 5 * time.Minute
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,8 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
jwtCache *jwtCache
}
type oauthAuthFlow struct {
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
}
@@ -370,6 +376,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -486,7 +495,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, err
}
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{
@@ -503,7 +512,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err
@@ -1077,28 +1086,41 @@ func (s *Server) GetPeerSSHHostKey(
}
s.mutex.Lock()
defer s.mutex.Unlock()
connectClient := s.connectClient
statusRecorder := s.statusRecorder
s.mutex.Unlock()
response := &proto.GetPeerSSHHostKeyResponse{
Found: false,
if connectClient == nil {
return nil, errors.New("client not initialized")
}
if s.statusRecorder == nil {
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
peerAddress := req.GetPeerAddress()
hostKey, found := engine.GetPeerSSHKey(peerAddress)
response := &proto.GetPeerSSHHostKeyResponse{
Found: found,
}
if !found {
return response, nil
}
fullStatus := s.statusRecorder.GetFullStatus()
peerAddress := req.GetPeerAddress()
response.SshHostKey = hostKey
// Search for peer by IP or FQDN
if statusRecorder == nil {
return response, nil
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
response.SshHostKey = peerState.SSHHostKey
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
response.Found = true
}
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
break
}
}
@@ -1106,6 +1128,137 @@ func (s *Server) GetPeerSSHHostKey(
return response, nil
}
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
// NB_SSH_JWT_CACHE_TTL=0 disables caching
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
func getJWTCacheTTL() time.Duration {
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
if envValue == "" {
return defaultJWTCacheTTL
}
seconds, err := strconv.Atoi(envValue)
if err != nil {
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
return defaultJWTCacheTTL
}
if seconds == 0 {
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
_ *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
return &proto.RequestJWTAuthResponse{
CachedToken: cachedToken,
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
}
s.mutex.Lock()
s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
return &proto.RequestJWTAuthResponse{
VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: authInfo.UserCode,
DeviceCode: authInfo.DeviceCode,
ExpiresIn: int64(authInfo.ExpiresIn),
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
// WaitJWTToken waits for JWT authentication completion
func (s *Server) WaitJWTToken(
ctx context.Context,
req *proto.WaitJWTTokenRequest,
) (*proto.WaitJWTTokenResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
oAuthFlow := s.oauthAuthFlow.flow
authInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
}
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
}
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
} else {
log.Debug("JWT caching disabled, not storing token")
}
s.mutex.Lock()
s.oauthAuthFlow = oauthAuthFlow{}
s.mutex.Unlock()
return &proto.WaitJWTTokenResponse{
Token: tokenInfo.GetTokenToUse(),
TokenType: tokenInfo.TokenType,
ExpiresIn: int64(tokenInfo.ExpiresIn),
}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func (s *Server) runProbes() {
if s.connectClient == nil {
return
@@ -1191,13 +1344,18 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
}
disableSSHAuth := false
if s.config.DisableSSHAuth != nil {
disableSSHAuth = *s.config.DisableSSHAuth
}
return &proto.GetConfigResponse{
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
@@ -1214,6 +1372,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHSFTP: enableSSHSFTP,
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
}, nil
}