mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
73
client/server/jwt_cache.go
Normal file
73
client/server/jwt_cache.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type jwtCache struct {
|
||||
mu sync.RWMutex
|
||||
enclave *memguard.Enclave
|
||||
expiresAt time.Time
|
||||
timer *time.Timer
|
||||
maxTokenSize int
|
||||
}
|
||||
|
||||
func newJWTCache() *jwtCache {
|
||||
return &jwtCache{
|
||||
maxTokenSize: 8192,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *jwtCache) store(token string, maxAge time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cleanup()
|
||||
|
||||
if c.timer != nil {
|
||||
c.timer.Stop()
|
||||
}
|
||||
|
||||
tokenBytes := []byte(token)
|
||||
c.enclave = memguard.NewEnclave(tokenBytes)
|
||||
|
||||
c.expiresAt = time.Now().Add(maxAge)
|
||||
|
||||
c.timer = time.AfterFunc(maxAge, func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cleanup()
|
||||
c.timer = nil
|
||||
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *jwtCache) get() (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.enclave == nil || time.Now().After(c.expiresAt) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
buffer, err := c.enclave.Open()
|
||||
if err != nil {
|
||||
log.Debugf("Failed to open JWT token enclave: %v", err)
|
||||
return "", false
|
||||
}
|
||||
defer buffer.Destroy()
|
||||
|
||||
token := string(buffer.Bytes())
|
||||
return token, true
|
||||
}
|
||||
|
||||
// cleanup destroys the secure enclave, must be called with lock held
|
||||
func (c *jwtCache) cleanup() {
|
||||
if c.enclave != nil {
|
||||
c.enclave = nil
|
||||
}
|
||||
}
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
|
||||
@@ -46,6 +46,9 @@ const (
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
// JWT token cache TTL for the client daemon
|
||||
defaultJWTCacheTTL = 5 * time.Minute
|
||||
|
||||
errRestoreResidualState = "failed to restore residual state: %v"
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||
@@ -81,6 +84,8 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,6 +376,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
|
||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
@@ -486,7 +495,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous oauth flow info")
|
||||
return &proto.LoginResponse{
|
||||
@@ -503,7 +512,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||
return nil, err
|
||||
@@ -1077,28 +1086,41 @@ func (s *Server) GetPeerSSHHostKey(
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
connectClient := s.connectClient
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
response := &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: false,
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
if s.statusRecorder == nil {
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
peerAddress := req.GetPeerAddress()
|
||||
hostKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
|
||||
response := &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
}
|
||||
|
||||
if !found {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
peerAddress := req.GetPeerAddress()
|
||||
response.SshHostKey = hostKey
|
||||
|
||||
// Search for peer by IP or FQDN
|
||||
if statusRecorder == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
if len(peerState.SSHHostKey) > 0 {
|
||||
response.SshHostKey = peerState.SSHHostKey
|
||||
response.PeerIP = peerState.IP
|
||||
response.PeerFQDN = peerState.FQDN
|
||||
response.Found = true
|
||||
}
|
||||
response.PeerIP = peerState.IP
|
||||
response.PeerFQDN = peerState.FQDN
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -1106,6 +1128,137 @@ func (s *Server) GetPeerSSHHostKey(
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
|
||||
// NB_SSH_JWT_CACHE_TTL=0 disables caching
|
||||
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
|
||||
func getJWTCacheTTL() time.Duration {
|
||||
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
|
||||
if envValue == "" {
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
seconds, err := strconv.Atoi(envValue)
|
||||
if err != nil {
|
||||
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
if seconds == 0 {
|
||||
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
|
||||
return 0
|
||||
}
|
||||
|
||||
ttl := time.Duration(seconds) * time.Second
|
||||
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
|
||||
return ttl
|
||||
}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
func (s *Server) RequestJWTAuth(
|
||||
ctx context.Context,
|
||||
_ *proto.RequestJWTAuthRequest,
|
||||
) (*proto.RequestJWTAuthResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||
}
|
||||
|
||||
jwtCacheTTL := getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
if cachedToken, found := s.jwtCache.get(); found {
|
||||
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: cachedToken,
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.flow = oAuthFlow
|
||||
s.oauthAuthFlow.info = authInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
DeviceCode: authInfo.DeviceCode,
|
||||
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
func (s *Server) WaitJWTToken(
|
||||
ctx context.Context,
|
||||
req *proto.WaitJWTTokenRequest,
|
||||
) (*proto.WaitJWTTokenResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
oAuthFlow := s.oauthAuthFlow.flow
|
||||
authInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
|
||||
}
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
|
||||
jwtCacheTTL := getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
s.jwtCache.store(token, jwtCacheTTL)
|
||||
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||
} else {
|
||||
log.Debug("JWT caching disabled, not storing token")
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
s.mutex.Unlock()
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: tokenInfo.GetTokenToUse(),
|
||||
TokenType: tokenInfo.TokenType,
|
||||
ExpiresIn: int64(tokenInfo.ExpiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
func (s *Server) runProbes() {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
@@ -1191,13 +1344,18 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
|
||||
}
|
||||
|
||||
disableSSHAuth := false
|
||||
if s.config.DisableSSHAuth != nil {
|
||||
disableSSHAuth = *s.config.DisableSSHAuth
|
||||
}
|
||||
|
||||
return &proto.GetConfigResponse{
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
@@ -1214,6 +1372,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHSFTP: enableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&nftables.ShutdownState{})
|
||||
mgr.RegisterState(&iptables.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user