[client] Add port forwarding to ssh proxy (#5031)

* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
This commit is contained in:
Viktor Liu
2026-01-07 12:18:04 +08:00
committed by GitHub
parent 7142d45ef3
commit f012fb8592
15 changed files with 1006 additions and 370 deletions

View File

@@ -98,19 +98,17 @@ func (a *Authorizer) Update(config *Config) {
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user
// Returns nil if authorized, or an error describing why authorization failed
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
// Authorize validates if a user is authorized to login as the specified OS user.
// Returns a success message describing how authorization was granted, or an error.
func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) {
if jwtUserID == "" {
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
return ErrEmptyUserID
return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID)
}
// Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil {
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
return fmt.Errorf("failed to hash user ID: %w", err)
return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err)
}
a.mu.RLock()
@@ -119,8 +117,7 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
// Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID)
if !found {
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
return ErrUserNotAuthorized
return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized)
}
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
@@ -128,12 +125,11 @@ func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
// Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
return nil
return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil
}
}
@@ -141,18 +137,15 @@ func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userI
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed)
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
return ErrNoMachineUserMapping
return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping)
}
// Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
return ErrUserNotMappedToOSUser
return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser)
}
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
return nil
return fmt.Sprintf("granted (index: %d)", userIndex), nil
}
// GetUserIDClaim returns the JWT claim name used to extract user IDs

View File

@@ -24,7 +24,7 @@ func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer.Update(config)
// Try to authorize a different user
err = authorizer.Authorize("unauthorized-user", "root")
_, err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
@@ -45,15 +45,15 @@ func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T)
authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed)
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
_, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user1", "postgres")
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
@@ -80,21 +80,21 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testi
authorizer.Update(config)
// user1 (index 0) should access root and admin
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "admin")
_, err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err)
// user2 (index 1) should access root and postgres
err = authorizer.Authorize("user2", "root")
_, err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
_, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user3 (index 2) should access postgres
err = authorizer.Authorize("user3", "postgres")
_, err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err)
}
@@ -121,22 +121,22 @@ func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testin
authorizer.Update(config)
// user1 (index 0) should NOT access postgres
err = authorizer.Authorize("user1", "postgres")
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin
err = authorizer.Authorize("user2", "admin")
_, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root
err = authorizer.Authorize("user3", "root")
_, err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin
err = authorizer.Authorize("user3", "admin")
_, err = authorizer.Authorize("user3", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
}
@@ -158,7 +158,7 @@ func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed)
err = authorizer.Authorize("user1", "postgres")
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
@@ -178,7 +178,7 @@ func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer.Update(config)
// Empty user ID should fail
err = authorizer.Authorize("", "root")
_, err = authorizer.Authorize("", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID)
}
@@ -211,12 +211,12 @@ func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
// All users should be authorized for root
for i := 0; i < 10; i++ {
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
_, err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i)
}
// User not in list should fail
err := authorizer.Authorize("unknown-user", "root")
_, err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
@@ -236,14 +236,14 @@ func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer.Update(config)
// user1 should be authorized
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Clear configuration
authorizer.Update(nil)
// user1 should no longer be authorized
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
@@ -267,16 +267,16 @@ func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer.Update(config)
// root should work
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// postgres should fail (no mapping)
err = authorizer.Authorize("user1", "postgres")
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping)
err = authorizer.Authorize("user1", "admin")
_, err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
@@ -301,7 +301,7 @@ func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID
err = authorizer.Authorize("user@example.com", "root")
_, err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err)
}
@@ -349,19 +349,19 @@ func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer.Update(config)
// First user should have access
err := authorizer.Authorize("user"+string(rune(0)), "root")
_, err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err)
// Middle user should have access
err = authorizer.Authorize("user"+string(rune(500)), "root")
_, err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err)
// Last user should have access
err = authorizer.Authorize("user"+string(rune(999)), "root")
_, err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err)
// User not in mapping should NOT have access
err = authorizer.Authorize("user"+string(rune(100)), "root")
_, err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err)
}
@@ -393,7 +393,7 @@ func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
if idx%2 == 0 {
user = "user2"
}
err := authorizer.Authorize(user, "root")
_, err := authorizer.Authorize(user, "root")
errChan <- err
}(i)
}
@@ -426,22 +426,22 @@ func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer.Update(config)
// All authorized users should be able to access any OS user
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
_, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "admin")
_, err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "ubuntu")
_, err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "nginx")
_, err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "docker")
_, err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err)
}
@@ -462,11 +462,11 @@ func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer.Update(config)
// user1 should have access
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard
err = authorizer.Authorize("unauthorized-user", "root")
_, err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
@@ -492,17 +492,17 @@ func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "root")
_, err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard
err = authorizer.Authorize("user1", "postgres")
_, err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "admin")
_, err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err)
}
@@ -526,29 +526,29 @@ func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer.Update(config)
// user1 can access root
err = authorizer.Authorize("user1", "root")
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// user2 can access postgres
err = authorizer.Authorize("user2", "postgres")
_, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user1 cannot access postgres
err = authorizer.Authorize("user1", "postgres")
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root
err = authorizer.Authorize("user2", "root")
_, err = authorizer.Authorize("user2", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users
err = authorizer.Authorize("user1", "admin")
_, err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
_, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
@@ -578,35 +578,35 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard
err = authorizer.Authorize("wasm", "root")
_, err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard")
err = authorizer.Authorize("wasm", "alice")
_, err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard")
err = authorizer.Authorize("wasm", "bob")
_, err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard")
err = authorizer.Authorize("wasm", "postgres")
_, err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
err = authorizer.Authorize("user2", "alice")
_, err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping")
err = authorizer.Authorize("user2", "bob")
_, err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping")
err = authorizer.Authorize("user2", "root")
_, err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "postgres")
_, err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied
err = authorizer.Authorize("user3", "root")
_, err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
@@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
log.Debugf("local port forwarding: close local connection: %v", err)
}
}()
channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil {
if strings.Contains(err.Error(), "administratively prohibited") {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
var openErr *ssh.OpenChannelError
if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
} else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
}
@@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
log.Debugf("local port forwarding: close remote channel: %v", err)
}
}()
go func() {
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("local forward copy error (local->remote): %v", err)
}
}()
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("local forward copy error (remote->local): %v", err)
}
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
return fmt.Errorf("send tcpip-forward request: %w", err)
}
if !ok {
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
return fmt.Errorf("remote port forwarding denied by server")
}
return nil
}
@@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
log.Debugf("remote port forwarding: close remote channel: %v", err)
}
}()
@@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
log.Debugf("remote port forwarding: close local connection: %v", err)
}
}()
go func() {
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("remote forward copy error (remote->local): %v", err)
}
}()
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("remote forward copy error (local->remote): %v", err)
}
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string {
}
return addresses
}
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -42,6 +43,14 @@ type SSHProxy struct {
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
browserOpener func(string) error
mu sync.RWMutex
backendClient *cryptossh.Client
// jwtToken is set once in runProxySSHServer before any handlers are called,
// so concurrent access is safe without additional synchronization.
jwtToken string
forwardedChannelsOnce sync.Once
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
@@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse
}
func (p *SSHProxy) Close() error {
p.mu.Lock()
backendClient := p.backendClient
p.backendClient = nil
p.mu.Unlock()
if backendClient != nil {
if err := backendClient.Close(); err != nil {
log.Debugf("close backend client: %v", err)
}
}
if p.conn != nil {
return p.conn.Close()
}
@@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort)
return p.runProxySSHServer(jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
p.jwtToken = jwtToken
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
Handler: p.handleSSHSession,
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
@@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
if !isPty && !hasCommand {
p.handleNonInteractiveSession(session, sshClient)
return
}
serverSession, err := sshClient.NewSession()
if err != nil {
@@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err)
@@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
}()
}
if len(session.Command()) > 0 {
if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
@@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
if err := session.Exit(exitErr.ExitStatus()); err != nil {
log.Debugf("set exit status: %v", err)
}
}
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
<-session.Context().Done()
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
@@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
// directTCPIPHandler handles local port forwarding (direct-tcpip channel).
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) {
var payload struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload")
return
}
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
log.Debugf("local port forwarding: %s", dest)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed")
return
}
backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err)
var openErr *cryptossh.OpenChannelError
if errors.As(err, &openErr) {
_ = newChan.Reject(openErr.Reason, openErr.Message)
} else {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
}
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := newChan.Accept()
if err != nil {
log.Debugf("local port forwarding: accept channel: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
// tcpipForwardHandler handles remote port forwarding (tcpip-forward request).
func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err)
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
if ok {
actualPort := reqPayload.Port
if reqPayload.Port == 0 && len(payload) >= 4 {
actualPort = binary.BigEndian.Uint32(payload)
}
log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort)
p.forwardedChannelsOnce.Do(func() {
go p.handleForwardedChannels(sshCtx, backendClient)
})
}
return ok, payload
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
// cancelTcpipForwardHandler handles cancel-tcpip-forward request.
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient := p.getBackendClient()
if backendClient == nil {
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
return ok, payload
}
// getOrCreateBackendClient returns the existing backend client or creates a new one.
func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.backendClient != nil {
return p.backendClient, nil
}
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
log.Debugf("connecting to backend %s", targetAddr)
client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken)
if err != nil {
return nil, err
}
log.Debugf("backend connection established to %s", targetAddr)
p.backendClient = client
return client, nil
}
// getBackendClient returns the existing backend client or nil.
func (p *SSHProxy) getBackendClient() *cryptossh.Client {
p.mu.RLock()
defer p.mu.RUnlock()
return p.backendClient
}
// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding.
// When the backend receives incoming connections on the forwarded port, it sends them as
// "forwarded-tcpip" channels which we need to proxy to the client.
func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) {
sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if !ok || sshConn == nil {
log.Debugf("no SSH connection in context for forwarded channels")
return
}
channelChan := backendClient.HandleChannelOpen("forwarded-tcpip")
for {
select {
case <-sshCtx.Done():
return
case newChannel, ok := <-channelChan:
if !ok {
return
}
go p.handleForwardedChannel(sshCtx, sshConn, newChannel)
}
}
}
// handleForwardedChannel handles a single forwarded-tcpip channel from the backend.
func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) {
backendChan, backendReqs, err := newChannel.Accept()
if err != nil {
log.Debugf("remote port forwarding: accept from backend: %v", err)
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData())
if err != nil {
log.Debugf("remote port forwarding: open to client: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -1,25 +1,32 @@
// Package server implements port forwarding for the SSH server.
//
// Security note: Port forwarding runs in the main server process without privilege separation.
// The attack surface is primarily io.Copy through well-tested standard library code, making it
// lower risk than shell execution which uses privilege-separated child processes. We enforce
// user-level port restrictions: non-privileged users cannot bind to ports < 1024.
package server
import (
"encoding/binary"
"fmt"
"io"
"net"
"runtime"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
// SessionKey uniquely identifies an SSH session
type SessionKey string
const privilegedPortThreshold = 1024
// ConnectionKey uniquely identifies a port forwarding connection within a session
type ConnectionKey string
// sessionKey uniquely identifies an SSH session
type sessionKey string
// ForwardKey uniquely identifies a port forwarding listener
type ForwardKey string
// forwardKey uniquely identifies a port forwarding listener
type forwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
type tcpipForwardMsg struct {
@@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
allowRemote := s.allowRemotePortForwarding
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowLocal {
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
return false
}
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
return true
}
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowRemote {
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
return false
}
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
return true
}
@@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
}
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
// Returns nil if allowed, error if denied.
// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to
// ports below 1024, mirroring the restriction they would face if binding directly.
//
// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user,
// though we don't actually switch users - port forwarding runs in the server process. The resolved
// user is used for privileged port access checks.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
if ctx == nil {
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
}
username := ctx.User()
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: false,
RequestedUsername: ctx.User(),
FeatureSupportsUserSwitch: true,
FeatureName: forwardType + " port forwarding",
})
@@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri
return result.Error
}
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
forwardType, result.User.Username, port)
if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil {
return err
}
return nil
}
// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports.
// This applies to remote port forwarding where the server binds a port on behalf of the user.
// On Windows, there is no privileged port restriction, so this check is skipped.
func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error {
if runtime.GOOS == "windows" {
return nil
}
isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward"
if !isBindOperation {
return nil
}
// Port 0 means "pick any available port", which will be >= 1024
if port == 0 || port >= privilegedPortThreshold {
return nil
}
if result.User != nil && isPrivilegedUsername(result.User.Username) {
return nil
}
username := "unknown"
if result.User != nil {
username = result.User.Username
}
return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port)
}
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
@@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto
return false, nil
}
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
sshConn, err := s.getSSHConnection(ctx)
if err != nil {
logger.Warnf("tcpip-forward request denied: %v", err)
@@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil
}
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
if s.removeRemoteForwardListener(key) {
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
return true, nil
}
@@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
logger := s.getRequestLogger(ctx)
defer func() {
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
if err := ln.Close(); err != nil {
log.Debugf("remote forward listener close error: %v", err)
} else {
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
}
}()
@@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
select {
case result := <-acceptChan:
if result.err != nil {
log.Debugf("remote forward accept error: %v", result.err)
logger.Debugf("remote forward accept error: %v", result.err)
return
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
return
}
}
}
// getRequestLogger creates a logger with user and remote address context
// getRequestLogger creates a logger with session/conn and jwt_user context
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
remoteAddr := "unknown"
username := "unknown"
if ctx != nil {
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
sessionKey := s.findSessionKeyByContext(ctx)
s.mu.RLock()
defer s.mu.RUnlock()
if state, exists := s.sessions[sessionKey]; exists {
logger := log.WithField("session", sessionKey)
if state.jwtUsername != "" {
logger = logger.WithField("jwt_user", state.jwtUsername)
}
username = ctx.User()
return logger
}
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
if ctx.RemoteAddr() != nil {
if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists {
return s.connLogger(connState)
}
}
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr))
}
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
@@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding
}
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg
@@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
s.storeRemoteForwardListener(key, ln)
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
@@ -288,44 +340,34 @@ type acceptResult struct {
// handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
sessionKey := s.findSessionKeyByContext(ctx)
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
logger := log.WithFields(log.Fields{
"session": sessionKey,
"conn": connID,
})
logger := s.getRequestLogger(ctx)
defer func() {
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}()
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if sshConn == nil {
sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if !ok || sshConn == nil {
logger.Debugf("remote forward: no SSH connection in context")
_ = conn.Close()
return
}
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
_ = conn.Close()
return
}
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil {
logger.Debugf("open forward channel: %v", err)
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
_ = conn.Close()
return
}
s.proxyForwardConnection(ctx, logger, conn, channel)
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
}
// openForwardChannel creates an SSH forwarded-tcpip channel
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) {
payload := struct {
ConnectedAddress string
ConnectedPort uint32
@@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
go cryptossh.DiscardRequests(reqs)
return channel, nil
}
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(channel, conn); err != nil {
logger.Debugf("copy error (conn->channel): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn, channel); err != nil {
logger.Debugf("copy error (channel->conn): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
// First copy finished, wait for second copy or context cancellation
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
}
}
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
@@ -40,6 +41,11 @@ const (
msgPrivilegedUserDisabled = "privileged user login is disabled"
cmdInteractiveShell = "<interactive shell>"
cmdPortForwarding = "<port forwarding>"
cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
@@ -90,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) {
}
}
// safeLogCommand returns a safe representation of the command for logging
// safeLogCommand returns a safe representation of the command for logging.
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return "<interactive shell>"
return cmdInteractiveShell
}
if len(cmd) == 1 {
return cmd[0]
@@ -101,26 +107,50 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
type sshConnectionState struct {
hasActivePortForward bool
username string
remoteAddr string
// connState tracks the state of an SSH connection for port forwarding and status display.
type connState struct {
username string
remoteAddr net.Addr
portForwards []string
jwtUsername string
}
// authKey uniquely identifies an authentication attempt by username and remote address.
// Used to temporarily store JWT username between passwordHandler and sessionHandler.
type authKey string
// connKey uniquely identifies an SSH connection by its remote address.
// Used to track authenticated connections for status display and port forwarding.
type connKey string
func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
}
// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP).
type sessionState struct {
session ssh.Session
sessionType string
jwtUsername string
}
type Server struct {
sshServer *ssh.Server
mu sync.RWMutex
hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
sessionCancels map[ConnectionKey]context.CancelFunc
sessionJWTUsers map[SessionKey]string
pendingAuthJWT map[authKey]string
sshServer *ssh.Server
mu sync.RWMutex
hostKeyPEM []byte
// sessions tracks active SSH sessions (shell, command, SFTP).
// These are created when a client opens a session channel and requests shell/exec/subsystem.
sessions map[sessionKey]*sessionState
// pendingAuthJWT temporarily stores JWT username during the auth→session handoff.
// Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler.
pendingAuthJWT map[authKey]string
// connections tracks all SSH connections by their remote address.
// Populated at authentication time, stores JWT username and port forwards for status display.
connections map[connKey]*connState
allowLocalPortForwarding bool
allowRemotePortForwarding bool
@@ -132,8 +162,7 @@ type Server struct {
wgAddress wgaddr.Address
remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
remoteForwardListeners map[forwardKey]net.Listener
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
@@ -167,6 +196,7 @@ type SessionInfo struct {
RemoteAddress string
Command string
JWTUsername string
PortForwards []string
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
@@ -175,11 +205,10 @@ func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session),
sessionJWTUsers: make(map[SessionKey]string),
sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
remoteForwardListeners: make(map[forwardKey]net.Listener),
connections: make(map[connKey]*connState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
@@ -265,14 +294,8 @@ func (s *Server) Stop() error {
s.sshServer = nil
maps.Clear(s.sessions)
maps.Clear(s.sessionJWTUsers)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.sshConnections)
for _, cancelFunc := range s.sessionCancels {
cancelFunc()
}
maps.Clear(s.sessionCancels)
maps.Clear(s.connections)
for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil {
@@ -284,32 +307,70 @@ func (s *Server) Stop() error {
return nil
}
// GetStatus returns the current status of the SSH server and active sessions
// GetStatus returns the current status of the SSH server and active sessions.
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()
defer s.mu.RUnlock()
enabled = s.sshServer != nil
reportedAddrs := make(map[string]bool)
for sessionKey, session := range s.sessions {
cmd := "<interactive shell>"
if len(session.Command()) > 0 {
cmd = safeLogCommand(session.Command())
for _, state := range s.sessions {
info := s.buildSessionInfo(state)
reportedAddrs[info.RemoteAddress] = true
sessions = append(sessions, info)
}
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
for key, connState := range s.connections {
remoteAddr := string(key)
if reportedAddrs[remoteAddr] {
continue
}
cmd := cmdNonInteractive
if len(connState.portForwards) > 0 {
cmd = cmdPortForwarding
}
jwtUsername := s.sessionJWTUsers[sessionKey]
sessions = append(sessions, SessionInfo{
Username: session.User(),
RemoteAddress: session.RemoteAddr().String(),
Username: connState.username,
RemoteAddress: remoteAddr,
Command: cmd,
JWTUsername: jwtUsername,
JWTUsername: connState.jwtUsername,
PortForwards: connState.portForwards,
})
}
return enabled, sessions
}
func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
session := state.session
cmd := state.sessionType
if cmd == "" {
cmd = safeLogCommand(session.Command())
}
remoteAddr := session.RemoteAddr().String()
info := SessionInfo{
Username: session.User(),
RemoteAddress: remoteAddr,
Command: cmd,
JWTUsername: state.jwtUsername,
}
connState, exists := s.connections[connKey(remoteAddr)]
if !exists {
return info
}
info.PortForwards = connState.portForwards
if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) {
info.Command = cmdPortForwarding
}
return info
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
@@ -520,69 +581,129 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr()
logger := s.getRequestLogger(ctx)
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
logger.Errorf("JWT validator initialization failed: %v", err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
logger.Warnf("JWT authentication failed: %v", err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
logger.Warnf("user validation failed: %v", err)
return false
}
logger = logger.WithField("jwt_user", userAuth.UserId)
s.mu.RLock()
authorizer := s.authorizer
s.mu.RUnlock()
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
msg, err := authorizer.Authorize(userAuth.UserId, osUsername)
if err != nil {
logger.Warnf("SSH auth denied: %v", err)
return false
}
logger.Infof("SSH auth %s", msg)
key := newAuthKey(osUsername, remoteAddr)
remoteAddrStr := ctx.RemoteAddr().String()
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.connections[connKey(remoteAddrStr)] = &connState{
username: ctx.User(),
remoteAddr: ctx.RemoteAddr(),
jwtUsername: userAuth.UserId,
}
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
return true
}
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
if state, exists := s.sshConnections[sshConn]; exists {
state.hasActivePortForward = true
} else {
s.sshConnections[sshConn] = &sshConnectionState{
hasActivePortForward: true,
username: username,
remoteAddr: remoteAddr,
key := connKey(remoteAddr.String())
if state, exists := s.connections[key]; exists {
if !slices.Contains(state.portForwards, forwardAddr) {
state.portForwards = append(state.portForwards, forwardAddr)
}
return
}
// Connection not in connections (non-JWT auth path)
s.connections[key] = &connState{
username: username,
remoteAddr: remoteAddr,
portForwards: []string{forwardAddr},
jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)],
}
}
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
state, exists := s.connections[connKey(remoteAddr.String())]
if !exists {
return
}
state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool {
return addr == forwardAddr
})
}
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
// trackedConn wraps a net.Conn to detect when it closes
type trackedConn struct {
net.Conn
server *Server
remoteAddr string
onceClose sync.Once
}
func (c *trackedConn) Close() error {
err := c.Conn.Close()
c.onceClose.Do(func() {
c.server.handleConnectionClose(c.remoteAddr)
})
return err
}
func (s *Server) handleConnectionClose(remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
key := connKey(remoteAddr)
state, exists := s.connections[key]
if exists && len(state.portForwards) > 0 {
s.connLogger(state).Info("port forwarding connection closed")
}
delete(s.connections, key)
}
func (s *Server) connLogger(state *connState) *log.Entry {
logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr))
if state.jwtUsername != "" {
logger = logger.WithField("jwt_user", state.jwtUsername)
}
return logger
}
func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
if ctx == nil {
return "unknown"
}
// Try to match by SSH connection
sshConn := ctx.Value(ssh.ContextKeyConn)
if sshConn == nil {
return "unknown"
@@ -591,19 +712,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
s.mu.RLock()
defer s.mu.RUnlock()
// Look through sessions to find one with matching connection
for sessionKey, session := range s.sessions {
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
for sessionKey, state := range s.sessions {
if state.session.Context().Value(ssh.ContextKeyConn) == sshConn {
return sessionKey
}
}
// If no session found, this might be during early connection setup
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
}
return "unknown"
@@ -644,7 +760,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
}
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
return conn
return &trackedConn{
Conn: conn,
server: s,
remoteAddr: conn.RemoteAddr().String(),
}
}
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
@@ -672,9 +792,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
"tcpip-forward": s.tcpipForwardHandler,
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
},
ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
ConnCallback: s.connectionValidator,
Version: serverVersion,
}
if s.jwtEnabled {
@@ -690,13 +809,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil
}
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
func (s *Server) removeRemoteForwardListener(key forwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -714,6 +833,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
}
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
logger := s.getRequestLogger(ctx)
var payload struct {
Host string
Port uint32
@@ -723,7 +844,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
log.Debugf("channel reject error: %v", err)
logger.Debugf("channel reject error: %v", err)
}
return
}
@@ -733,19 +854,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()
if !allowLocal {
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
// Check privilege requirements for the destination port
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}

View File

@@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
}
}
func TestServer_PrivilegedPortAccess(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
}
server := New(serverConfig)
server.SetAllowRemotePortForwarding(true)
tests := []struct {
name string
forwardType string
port uint32
username string
expectError bool
errorMsg string
skipOnWindows bool
}{
{
name: "non-root user remote forward privileged port",
forwardType: "remote",
port: 80,
username: "testuser",
expectError: true,
errorMsg: "cannot bind to privileged port",
skipOnWindows: true,
},
{
name: "non-root user tcpip-forward privileged port",
forwardType: "tcpip-forward",
port: 443,
username: "testuser",
expectError: true,
errorMsg: "cannot bind to privileged port",
skipOnWindows: true,
},
{
name: "non-root user remote forward unprivileged port",
forwardType: "remote",
port: 8080,
username: "testuser",
expectError: false,
},
{
name: "non-root user remote forward port 0",
forwardType: "remote",
port: 0,
username: "testuser",
expectError: false,
},
{
name: "root user remote forward privileged port",
forwardType: "remote",
port: 22,
username: "root",
expectError: false,
},
{
name: "local forward privileged port allowed for non-root",
forwardType: "local",
port: 80,
username: "testuser",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skipOnWindows && runtime.GOOS == "windows" {
t.Skip("Windows does not have privileged port restrictions")
}
result := PrivilegeCheckResult{
Allowed: true,
User: &user.User{Username: tt.username},
}
err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result)
if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestServer_PortConflictHandling(t *testing.T) {
// Test that multiple sessions requesting the same local port are handled naturally by the OS
// Get current user for SSH connection
@@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
})
}
}
func TestServer_PortForwardingOnlySession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
expectAllowed bool
description string
}{
{
name: "session_allowed_with_local_forwarding",
allowLocalForwarding: true,
allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
},
{
name: "session_allowed_with_remote_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
},
{
name: "session_allowed_with_both",
allowLocalForwarding: true,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
},
{
name: "session_denied_without_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
serverAddr := StartTestServer(t, server)
defer func() {
_ = server.Stop()
}()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
_ = client.Close()
}()
// Execute a command without PTY - this simulates ssh -T with no command
// The server should either allow it (port forwarding enabled) or reject it
output, err := client.ExecuteCommand(ctx, "")
if tt.expectAllowed {
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
})
}
}

View File

@@ -6,37 +6,45 @@ import (
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
)
// associateJWTUsername extracts pending JWT username for the session and associates it with the session state.
// Returns the JWT username (empty if none) for logging purposes.
func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string {
key := newAuthKey(sess.User(), sess.RemoteAddr())
s.mu.Lock()
defer s.mu.Unlock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername == "" {
return ""
}
if state, exists := s.sessions[sessionKey]; exists {
state.jwtUsername = jwtUsername
}
delete(s.pendingAuthJWT, key)
return jwtUsername
}
// sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) {
sessionKey := s.registerSession(session)
key := newAuthKey(session.User(), session.RemoteAddr())
s.mu.Lock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername != "" {
s.sessionJWTUsers[sessionKey] = jwtUsername
delete(s.pendingAuthJWT, key)
}
s.mu.Unlock()
sessionKey := s.registerSession(session, "")
jwtUsername := s.associateJWTUsername(session, sessionKey)
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
} else {
logger.Infof("SSH session started")
}
logger.Info("SSH session started")
sessionStart := time.Now()
defer s.unregisterSession(sessionKey, session)
defer s.unregisterSession(sessionKey)
defer func() {
duration := time.Since(sessionStart).Round(time.Millisecond)
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
@@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) {
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
s.rejectInvalidSession(logger, session)
// ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
}
}
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
logger.Debugf(errWriteSession, err)
// handleNonInteractiveSession handles sessions that have no PTY and no command.
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
}
if err := session.Exit(1); err != nil {
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
}
func (s *Server) registerSession(session ssh.Session) SessionKey {
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
}
}
func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey {
sessionID := session.Context().Value(ssh.ContextKeySessionID)
if sessionID == nil {
sessionID = fmt.Sprintf("%p", session)
}
// Create a short 4-byte identifier from the full session ID
hasher := sha256.New()
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
hash := hasher.Sum(nil)
@@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
remoteAddr := session.RemoteAddr().String()
username := session.User()
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
s.mu.Lock()
s.sessions[sessionKey] = session
s.sessions[sessionKey] = &sessionState{
session: session,
sessionType: sessionType,
}
s.mu.Unlock()
return sessionKey
}
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
func (s *Server) unregisterSession(sessionKey sessionKey) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionKey)
delete(s.sessionJWTUsers, sessionKey)
// Cancel all port forwarding connections for this session
var connectionsToCancel []ConnectionKey
for key := range s.sessionCancels {
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
connectionsToCancel = append(connectionsToCancel, key)
}
}
for _, key := range connectionsToCancel {
if cancelFunc, exists := s.sessionCancels[key]; exists {
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
cancelFunc()
delete(s.sessionCancels, key)
}
}
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
delete(s.sshConnections, sshConn)
}
}
s.mu.Unlock()
}
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {

View File

@@ -18,14 +18,26 @@ func (s *Server) SetAllowSFTP(allow bool) {
// sftpSubsystemHandler handles SFTP subsystem requests
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
sessionKey := s.registerSession(sess, cmdSFTP)
defer s.unregisterSession(sessionKey)
jwtUsername := s.associateJWTUsername(sess, sessionKey)
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
}
logger.Info("SFTP session started")
defer logger.Info("SFTP session closed")
s.mu.RLock()
allowSFTP := s.allowSFTP
s.mu.RUnlock()
if !allowSFTP {
log.Debugf("SFTP subsystem request denied: SFTP disabled")
logger.Debug("SFTP subsystem request denied: SFTP disabled")
if err := sess.Exit(1); err != nil {
log.Debugf("SFTP session exit failed: %v", err)
logger.Debugf("SFTP session exit: %v", err)
}
return
}
@@ -37,31 +49,27 @@ func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
})
if !result.Allowed {
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
logger.Warnf("SFTP access denied: %v", result.Error)
if err := sess.Exit(1); err != nil {
log.Debugf("exit SFTP session: %v", err)
logger.Debugf("exit SFTP session: %v", err)
}
return
}
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
if !result.RequiresUserSwitching {
if err := s.executeSftpDirect(sess); err != nil {
log.Errorf("SFTP direct execution: %v", err)
logger.Errorf("SFTP direct execution: %v", err)
}
return
}
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
log.Errorf("SFTP privilege drop execution: %v", err)
logger.Errorf("SFTP privilege drop execution: %v", err)
}
}
// executeSftpDirect executes SFTP directly without privilege dropping
func (s *Server) executeSftpDirect(sess ssh.Session) error {
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
sftpServer, err := sftp.NewServer(sess)
if err != nil {
return fmt.Errorf("SFTP server creation: %w", err)