mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-14 04:39:54 +00:00
[client, proxy] Harden uspfilter conntrack and share TCP relay (#5936)
This commit is contained in:
@@ -25,6 +25,7 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
go c.handleLocalForward(ctx, localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local port forwarding: close local connection: %v", err)
|
||||
@@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
}
|
||||
}()
|
||||
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
@@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan := <-channelRequests:
|
||||
case newChan, ok := <-channelRequests:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
|
||||
// channel open indefinitely; the relay itself runs under the outer ctx.
|
||||
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
|
||||
var dialer net.Dialer
|
||||
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
|
||||
cancelDial()
|
||||
if err != nil {
|
||||
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
|
||||
}
|
||||
}()
|
||||
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
|
||||
@@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
return addresses
|
||||
}
|
||||
|
||||
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
|
||||
// It waits for both directions to complete before returning.
|
||||
// The caller is responsible for closing the connections.
|
||||
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
|
||||
// It waits for both directions to complete or for context cancellation before returning.
|
||||
// Both connections are closed when the function returns.
|
||||
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
_ = conn1.Close()
|
||||
_ = conn2.Close()
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
@@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
const privilegedPortThreshold = 1024
|
||||
@@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
return
|
||||
}
|
||||
|
||||
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
|
||||
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -53,6 +54,10 @@ const (
|
||||
DefaultJWTMaxTokenAge = 10 * 60
|
||||
)
|
||||
|
||||
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
|
||||
// the forwarded destination before rejecting the SSH channel.
|
||||
const directTCPIPDialTimeout = 30 * time.Second
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
@@ -933,5 +938,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s", hostPort)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
|
||||
}
|
||||
|
||||
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
|
||||
// DirectTCPIPHandler. The upstream handler closes both sides on the first
|
||||
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
|
||||
// own terms.
|
||||
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
|
||||
dest := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
|
||||
dconn, err := dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
_ = dconn.Close()
|
||||
return
|
||||
}
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
|
||||
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user