mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
169 lines
4.8 KiB
Go
169 lines
4.8 KiB
Go
package server
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
log "github.com/sirupsen/logrus"
|
|
cryptossh "golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// sessionHandler handles SSH sessions
|
|
func (s *Server) sessionHandler(session ssh.Session) {
|
|
sessionKey := s.registerSession(session)
|
|
|
|
key := newAuthKey(session.User(), session.RemoteAddr())
|
|
s.mu.Lock()
|
|
jwtUsername := s.pendingAuthJWT[key]
|
|
if jwtUsername != "" {
|
|
s.sessionJWTUsers[sessionKey] = jwtUsername
|
|
delete(s.pendingAuthJWT, key)
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
logger := log.WithField("session", sessionKey)
|
|
if jwtUsername != "" {
|
|
logger = logger.WithField("jwt_user", jwtUsername)
|
|
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
|
|
} else {
|
|
logger.Infof("SSH session started")
|
|
}
|
|
sessionStart := time.Now()
|
|
|
|
defer s.unregisterSession(sessionKey, session)
|
|
defer func() {
|
|
duration := time.Since(sessionStart).Round(time.Millisecond)
|
|
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
logger.Warnf("close session after %v: %v", duration, err)
|
|
}
|
|
logger.Infof("SSH session closed after %v", duration)
|
|
}()
|
|
|
|
privilegeResult, err := s.userPrivilegeCheck(session.User())
|
|
if err != nil {
|
|
s.handlePrivError(logger, session, err)
|
|
return
|
|
}
|
|
|
|
ptyReq, winCh, isPty := session.Pty()
|
|
hasCommand := len(session.Command()) > 0
|
|
|
|
switch {
|
|
case isPty && hasCommand:
|
|
// ssh -t <host> <cmd> - Pty command execution
|
|
s.handleCommand(logger, session, privilegeResult, winCh)
|
|
case isPty:
|
|
// ssh <host> - Pty interactive session (login)
|
|
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
|
case hasCommand:
|
|
// ssh <host> <cmd> - non-Pty command execution
|
|
s.handleCommand(logger, session, privilegeResult, nil)
|
|
default:
|
|
s.rejectInvalidSession(logger, session)
|
|
}
|
|
}
|
|
|
|
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
|
|
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
|
|
logger.Debugf(errWriteSession, err)
|
|
}
|
|
if err := session.Exit(1); err != nil {
|
|
logSessionExitError(logger, err)
|
|
}
|
|
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
|
|
}
|
|
|
|
func (s *Server) registerSession(session ssh.Session) SessionKey {
|
|
sessionID := session.Context().Value(ssh.ContextKeySessionID)
|
|
if sessionID == nil {
|
|
sessionID = fmt.Sprintf("%p", session)
|
|
}
|
|
|
|
// Create a short 4-byte identifier from the full session ID
|
|
hasher := sha256.New()
|
|
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
|
|
hash := hasher.Sum(nil)
|
|
shortID := hex.EncodeToString(hash[:4])
|
|
|
|
remoteAddr := session.RemoteAddr().String()
|
|
username := session.User()
|
|
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
|
|
|
|
s.mu.Lock()
|
|
s.sessions[sessionKey] = session
|
|
s.mu.Unlock()
|
|
|
|
return sessionKey
|
|
}
|
|
|
|
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
|
|
s.mu.Lock()
|
|
delete(s.sessions, sessionKey)
|
|
delete(s.sessionJWTUsers, sessionKey)
|
|
|
|
// Cancel all port forwarding connections for this session
|
|
var connectionsToCancel []ConnectionKey
|
|
for key := range s.sessionCancels {
|
|
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
|
|
connectionsToCancel = append(connectionsToCancel, key)
|
|
}
|
|
}
|
|
|
|
for _, key := range connectionsToCancel {
|
|
if cancelFunc, exists := s.sessionCancels[key]; exists {
|
|
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
|
|
cancelFunc()
|
|
delete(s.sessionCancels, key)
|
|
}
|
|
}
|
|
|
|
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
|
|
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
|
|
delete(s.sshConnections, sshConn)
|
|
}
|
|
}
|
|
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
|
|
logger.Warnf("user privilege check failed: %v", err)
|
|
|
|
errorMsg := s.buildUserLookupErrorMessage(err)
|
|
|
|
if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
|
|
logger.Debugf(errWriteSession, writeErr)
|
|
}
|
|
if exitErr := session.Exit(1); exitErr != nil {
|
|
logSessionExitError(logger, exitErr)
|
|
}
|
|
}
|
|
|
|
// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
|
|
func (s *Server) buildUserLookupErrorMessage(err error) string {
|
|
var privilegedErr *PrivilegedUserError
|
|
|
|
switch {
|
|
case errors.As(err, &privilegedErr):
|
|
if privilegedErr.Username == "root" {
|
|
return "root login is disabled on this SSH server\n"
|
|
}
|
|
return "privileged user access is disabled on this SSH server\n"
|
|
|
|
case errors.Is(err, ErrPrivilegeRequired):
|
|
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
|
|
|
|
case errors.Is(err, ErrPrivilegedUserSwitch):
|
|
return "Cannot switch to privileged user - current user lacks required privileges\n"
|
|
|
|
default:
|
|
return "User authentication failed\n"
|
|
}
|
|
}
|