mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
393 lines
9.8 KiB
Go
393 lines
9.8 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
log "github.com/sirupsen/logrus"
|
|
cryptossh "golang.org/x/crypto/ssh"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
"github.com/netbirdio/netbird/client/proto"
|
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
|
"github.com/netbirdio/netbird/version"
|
|
)
|
|
|
|
// Timeouts used while connecting to the backend SSH server.
const (
	// sshConnectionTimeout is the timeout for SSH TCP connection establishment
	sshConnectionTimeout = 2 * time.Minute
	// sshHandshakeTimeout is the timeout for SSH handshake completion
	sshHandshakeTimeout = 30 * time.Second
)

// jwtAuthErrorMsg is the format string used to wrap JWT authentication errors.
const jwtAuthErrorMsg = "JWT authentication: %w"
|
|
|
|
type SSHProxy struct {
|
|
daemonAddr string
|
|
targetHost string
|
|
targetPort int
|
|
stderr io.Writer
|
|
conn *grpc.ClientConn
|
|
daemonClient proto.DaemonServiceClient
|
|
}
|
|
|
|
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
|
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
|
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("connect to daemon: %w", err)
|
|
}
|
|
|
|
return &SSHProxy{
|
|
daemonAddr: daemonAddr,
|
|
targetHost: targetHost,
|
|
targetPort: targetPort,
|
|
stderr: stderr,
|
|
conn: grpcConn,
|
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
|
}, nil
|
|
}
|
|
|
|
func (p *SSHProxy) Close() error {
|
|
if p.conn != nil {
|
|
return p.conn.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *SSHProxy) Connect(ctx context.Context) error {
|
|
hint := profilemanager.GetLoginHint()
|
|
|
|
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
|
|
if err != nil {
|
|
return fmt.Errorf(jwtAuthErrorMsg, err)
|
|
}
|
|
|
|
return p.runProxySSHServer(ctx, jwtToken)
|
|
}
|
|
|
|
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
|
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
|
|
|
sshServer := &ssh.Server{
|
|
Handler: func(s ssh.Session) {
|
|
p.handleSSHSession(ctx, s, jwtToken)
|
|
},
|
|
ChannelHandlers: map[string]ssh.ChannelHandler{
|
|
"session": ssh.DefaultSessionHandler,
|
|
"direct-tcpip": p.directTCPIPHandler,
|
|
},
|
|
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
|
"sftp": func(s ssh.Session) {
|
|
p.sftpSubsystemHandler(s, jwtToken)
|
|
},
|
|
},
|
|
RequestHandlers: map[string]ssh.RequestHandler{
|
|
"tcpip-forward": p.tcpipForwardHandler,
|
|
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
|
|
},
|
|
Version: serverVersion,
|
|
}
|
|
|
|
hostKey, err := generateHostKey()
|
|
if err != nil {
|
|
return fmt.Errorf("generate host key: %w", err)
|
|
}
|
|
sshServer.HostSigners = []ssh.Signer{hostKey}
|
|
|
|
conn := &stdioConn{
|
|
stdin: os.Stdin,
|
|
stdout: os.Stdout,
|
|
}
|
|
|
|
sshServer.HandleConn(conn)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
|
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
|
|
|
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
|
return
|
|
}
|
|
defer func() { _ = sshClient.Close() }()
|
|
|
|
serverSession, err := sshClient.NewSession()
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
|
return
|
|
}
|
|
defer func() { _ = serverSession.Close() }()
|
|
|
|
serverSession.Stdin = session
|
|
serverSession.Stdout = session
|
|
serverSession.Stderr = session.Stderr()
|
|
|
|
ptyReq, winCh, isPty := session.Pty()
|
|
if isPty {
|
|
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
|
|
log.Debugf("PTY request to backend: %v", err)
|
|
}
|
|
|
|
go func() {
|
|
for win := range winCh {
|
|
if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
|
|
log.Debugf("window change: %v", err)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
if len(session.Command()) > 0 {
|
|
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
|
log.Debugf("run command: %v", err)
|
|
p.handleProxyExitCode(session, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
if err = serverSession.Shell(); err != nil {
|
|
log.Debugf("start shell: %v", err)
|
|
return
|
|
}
|
|
if err := serverSession.Wait(); err != nil {
|
|
log.Debugf("session wait: %v", err)
|
|
p.handleProxyExitCode(session, err)
|
|
}
|
|
}
|
|
|
|
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
|
var exitErr *cryptossh.ExitError
|
|
if errors.As(err, &exitErr) {
|
|
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
|
|
log.Debugf("set exit status: %v", exitErr)
|
|
}
|
|
}
|
|
}
|
|
|
|
func generateHostKey() (ssh.Signer, error) {
|
|
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate ED25519 key: %w", err)
|
|
}
|
|
|
|
signer, err := cryptossh.ParsePrivateKey(keyPEM)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse private key: %w", err)
|
|
}
|
|
|
|
return signer, nil
|
|
}
|
|
|
|
// stdioConn exposes a reader/writer pair through the net.Conn method set so an
// SSH server can be run over standard input and output.
type stdioConn struct {
	stdin  io.Reader
	stdout io.Writer
	closed bool       // set by Close; guarded by mu
	mu     sync.Mutex // protects closed
}

// isClosed reports whether Close has been called.
func (c *stdioConn) isClosed() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.closed
}

// stdioAddr returns the placeholder address reported for both ends of the
// connection.
func stdioAddr() net.Addr {
	return &net.UnixAddr{Name: "stdio", Net: "unix"}
}

// Read reads from the underlying reader, or returns io.EOF once the connection
// has been closed. The lock is not held during the read itself.
func (c *stdioConn) Read(b []byte) (n int, err error) {
	if c.isClosed() {
		return 0, io.EOF
	}
	return c.stdin.Read(b)
}

// Write writes to the underlying writer, or returns io.ErrClosedPipe once the
// connection has been closed. The lock is not held during the write itself.
func (c *stdioConn) Write(b []byte) (n int, err error) {
	if c.isClosed() {
		return 0, io.ErrClosedPipe
	}
	return c.stdout.Write(b)
}

// Close marks the connection as closed. The underlying reader and writer are
// left untouched. It always returns nil.
func (c *stdioConn) Close() error {
	c.mu.Lock()
	c.closed = true
	c.mu.Unlock()
	return nil
}

// LocalAddr returns a placeholder unix address named "stdio".
func (c *stdioConn) LocalAddr() net.Addr {
	return stdioAddr()
}

// RemoteAddr returns a placeholder unix address named "stdio".
func (c *stdioConn) RemoteAddr() net.Addr {
	return stdioAddr()
}

// SetDeadline is a no-op; deadlines are not supported.
func (c *stdioConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline is a no-op; deadlines are not supported.
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op; deadlines are not supported.
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
|
|
|
|
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
|
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
|
}
|
|
|
|
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
|
ctx, cancel := context.WithCancel(s.Context())
|
|
defer cancel()
|
|
|
|
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
|
|
|
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
|
|
_ = s.Exit(1)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := sshClient.Close(); err != nil {
|
|
log.Debugf("close SSH client: %v", err)
|
|
}
|
|
}()
|
|
|
|
serverSession, err := sshClient.NewSession()
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
|
|
_ = s.Exit(1)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := serverSession.Close(); err != nil {
|
|
log.Debugf("close server session: %v", err)
|
|
}
|
|
}()
|
|
|
|
stdin, stdout, err := p.setupSFTPPipes(serverSession)
|
|
if err != nil {
|
|
log.Debugf("setup SFTP pipes: %v", err)
|
|
_ = s.Exit(1)
|
|
return
|
|
}
|
|
|
|
if err := serverSession.RequestSubsystem("sftp"); err != nil {
|
|
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
|
|
_ = s.Exit(1)
|
|
return
|
|
}
|
|
|
|
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
|
|
}
|
|
|
|
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
|
|
stdin, err := serverSession.StdinPipe()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
|
|
}
|
|
|
|
stdout, err := serverSession.StdoutPipe()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
|
|
}
|
|
|
|
return stdin, stdout, nil
|
|
}
|
|
|
|
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
|
|
copyErrCh := make(chan error, 2)
|
|
|
|
go func() {
|
|
_, err := io.Copy(stdin, s)
|
|
if err != nil {
|
|
log.Debugf("SFTP client to server copy: %v", err)
|
|
}
|
|
if err := stdin.Close(); err != nil {
|
|
log.Debugf("close stdin: %v", err)
|
|
}
|
|
copyErrCh <- err
|
|
}()
|
|
|
|
go func() {
|
|
_, err := io.Copy(s, stdout)
|
|
if err != nil {
|
|
log.Debugf("SFTP server to client copy: %v", err)
|
|
}
|
|
copyErrCh <- err
|
|
}()
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
if err := serverSession.Close(); err != nil {
|
|
log.Debugf("force close server session on context cancellation: %v", err)
|
|
}
|
|
}()
|
|
|
|
for i := 0; i < 2; i++ {
|
|
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
|
|
log.Debugf("SFTP copy error: %v", err)
|
|
}
|
|
}
|
|
|
|
if err := serverSession.Wait(); err != nil {
|
|
log.Debugf("SFTP session ended: %v", err)
|
|
}
|
|
}
|
|
|
|
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
|
return false, []byte("port forwarding not supported in proxy")
|
|
}
|
|
|
|
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
|
return true, nil
|
|
}
|
|
|
|
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
|
config := &cryptossh.ClientConfig{
|
|
User: user,
|
|
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
|
|
Timeout: sshHandshakeTimeout,
|
|
HostKeyCallback: p.verifyHostKey,
|
|
}
|
|
|
|
dialer := &net.Dialer{
|
|
Timeout: sshConnectionTimeout,
|
|
}
|
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("connect to server: %w", err)
|
|
}
|
|
|
|
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, fmt.Errorf("SSH handshake: %w", err)
|
|
}
|
|
|
|
return cryptossh.NewClient(clientConn, chans, reqs), nil
|
|
}
|
|
|
|
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
|
|
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
|
|
callback := nbssh.CreateHostKeyCallback(verifier)
|
|
return callback(hostname, remote, key)
|
|
}
|