mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-11 03:09:55 +00:00
[management, client] Add IPv6 overlay support (#5631)
This commit is contained in:
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -91,7 +92,8 @@ type Manager struct {
|
||||
// PeerSSHInfo represents a peer's SSH configuration information
|
||||
type PeerSSHInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
IP netip.Addr
|
||||
IPv6 netip.Addr
|
||||
FQDN string
|
||||
}
|
||||
|
||||
@@ -210,8 +212,11 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
|
||||
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
var hostPatterns []string
|
||||
if peer.IP != "" {
|
||||
hostPatterns = append(hostPatterns, peer.IP)
|
||||
if peer.IP.IsValid() {
|
||||
hostPatterns = append(hostPatterns, peer.IP.String())
|
||||
}
|
||||
if peer.IPv6.IsValid() {
|
||||
hostPatterns = append(hostPatterns, peer.IPv6.String())
|
||||
}
|
||||
if peer.FQDN != "" {
|
||||
hostPatterns = append(hostPatterns, peer.FQDN)
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -28,12 +29,12 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
peers := []PeerSSHInfo{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
IP: netip.MustParseAddr("100.125.1.1"),
|
||||
FQDN: "peer1.nb.internal",
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
IP: netip.MustParseAddr("100.125.1.2"),
|
||||
FQDN: "peer2.nb.internal",
|
||||
},
|
||||
}
|
||||
@@ -101,7 +102,7 @@ func TestManager_PeerLimit(t *testing.T) {
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
@@ -127,8 +128,8 @@ func TestManager_MatchHostFormat(t *testing.T) {
|
||||
}
|
||||
|
||||
peers := []PeerSSHInfo{
|
||||
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
|
||||
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
|
||||
{Hostname: "peer1", IP: netip.MustParseAddr("100.125.1.1"), FQDN: "peer1.nb.internal"},
|
||||
{Hostname: "peer2", IP: netip.MustParseAddr("100.125.1.2"), FQDN: "peer2.nb.internal"},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
@@ -167,7 +168,7 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -321,7 +321,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
|
||||
return
|
||||
}
|
||||
|
||||
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
|
||||
dest := net.JoinHostPort(payload.DestAddr, strconv.Itoa(int(payload.DestPort)))
|
||||
log.Debugf("local port forwarding: %s", dest)
|
||||
|
||||
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
|
||||
|
||||
@@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
if !allowLocal {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
|
||||
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))))
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
|
||||
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
if !allowRemote {
|
||||
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
|
||||
logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))))
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
|
||||
logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
|
||||
key := forwardKey(hostPort)
|
||||
if s.removeRemoteForwardListener(key) {
|
||||
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
|
||||
forwardAddr := "-R " + hostPort
|
||||
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
||||
logger.Infof("remote port forwarding cancelled: %s", hostPort)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
||||
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
|
||||
logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
|
||||
}
|
||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
|
||||
logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port))))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
|
||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||
}
|
||||
|
||||
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||
s.storeRemoteForwardListener(key, ln)
|
||||
|
||||
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
|
||||
forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||
|
||||
response := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(response, actualPort)
|
||||
|
||||
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
|
||||
logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))))
|
||||
return true, response
|
||||
}
|
||||
|
||||
@@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
|
||||
logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -137,10 +138,11 @@ type sessionState struct {
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
listener net.Listener
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sshServer *ssh.Server
|
||||
listener net.Listener
|
||||
extraListeners []net.Listener
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
|
||||
// sessions tracks active SSH sessions (shell, command, SFTP).
|
||||
// These are created when a client opens a session channel and requests shell/exec/subsystem.
|
||||
@@ -254,6 +256,35 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddListener starts serving SSH on an additional address (e.g. IPv6).
|
||||
// Must be called after Start.
|
||||
func (s *Server) AddListener(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
srv := s.sshServer
|
||||
if srv == nil {
|
||||
s.mu.Unlock()
|
||||
return errors.New("SSH server is not running")
|
||||
}
|
||||
|
||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("create listener: %w", err)
|
||||
}
|
||||
|
||||
s.extraListeners = append(s.extraListeners, ln)
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("SSH server also listening on %s", addrDesc)
|
||||
|
||||
go func() {
|
||||
if err := srv.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
||||
log.Errorf("SSH server error on %s: %v", addrDesc, err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
@@ -291,6 +322,8 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
extraListeners := s.extraListeners
|
||||
s.extraListeners = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
// Close outside the lock: session handlers need s.mu for unregisterSession.
|
||||
@@ -298,6 +331,12 @@ func (s *Server) Stop() error {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
for _, ln := range extraListeners {
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("close extra SSH listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
@@ -749,11 +788,10 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
|
||||
|
||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
s.mu.RLock()
|
||||
netbirdNetwork := s.wgAddress.Network
|
||||
localIP := s.wgAddress.IP
|
||||
wgAddr := s.wgAddress
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
|
||||
if !wgAddr.Network.IsValid() || !wgAddr.IP.IsValid() {
|
||||
return conn
|
||||
}
|
||||
|
||||
@@ -769,14 +807,17 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||
return nil
|
||||
}
|
||||
remoteIP = remoteIP.Unmap()
|
||||
|
||||
// Block connections from our own IP (prevent local apps from connecting to ourselves)
|
||||
if remoteIP == localIP {
|
||||
if remoteIP == wgAddr.IP || wgAddr.IPv6.IsValid() && remoteIP == wgAddr.IPv6 {
|
||||
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !netbirdNetwork.Contains(remoteIP) {
|
||||
inV4 := wgAddr.Network.Contains(remoteIP)
|
||||
inV6 := wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(remoteIP)
|
||||
if !inV4 && !inV6 {
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
@@ -876,20 +917,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowLocal {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
|
||||
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||
return
|
||||
}
|
||||
|
||||
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
|
||||
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
|
||||
forwardAddr := "-L " + hostPort
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
logger.Infof("local port forwarding: %s", hostPort)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user