diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index 1419bc262..9ef70bf6e 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -41,6 +41,14 @@ func (e *Engine) setupSSHPortRedirection() error { } log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr) + if v6 := e.wgInterface.Address().IPv6; v6.IsValid() { + if err := e.firewall.AddInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil { + log.Warnf("failed to add IPv6 SSH port redirection: %v", err) + } else { + log.Infof("SSH port redirection enabled: [%s]:22 -> [%s]:22022", v6, v6) + } + } + return nil } @@ -137,12 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] continue } - peerIP := e.extractPeerIP(peerConfig) + peerIP, peerIPv6 := e.extractPeerIPs(peerConfig) hostname := e.extractHostname(peerConfig) peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ Hostname: hostname, IP: peerIP, + IPv6: peerIPv6, FQDN: peerConfig.GetFqdn(), }) } @@ -150,16 +159,26 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] return peerInfo } -// extractPeerIP extracts IP address from peer's allowed IPs -func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string { - if len(peerConfig.GetAllowedIps()) == 0 { - return "" +// extractPeerIPs extracts IPv4 and IPv6 overlay addresses from peer's allowed IPs. +// Only considers host routes (/32, /128) within the overlay networks to avoid +// picking up routed prefixes or static routes like 2620:fe::fe/128. +func (e *Engine) extractPeerIPs(peerConfig *mgmProto.RemotePeerConfig) (v4, v6 netip.Addr) { + wgAddr := e.wgInterface.Address() + for _, allowedIP := range peerConfig.GetAllowedIps() { + prefix, err := netip.ParsePrefix(allowedIP) + if err != nil { + log.Warnf("failed to parse AllowedIP %q: %v", allowedIP, err) + continue + } + addr := prefix.Addr().Unmap() + switch { + case addr.Is4() && prefix.Bits() == 32 && wgAddr.Network.Contains(addr) && !v4.IsValid(): + v4 = addr + case addr.Is6() && prefix.Bits() == 128 && wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr) && !v6.IsValid(): + v6 = addr + } } - - if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil { - return prefix.Addr().String() - } - return "" + return v4, v6 } // extractHostname extracts short hostname from FQDN @@ -208,7 +227,7 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { fullStatus := statusRecorder.GetFullStatus() for _, peerState := range fullStatus.Peers { - if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress || peerState.IPv6 == peerAddress { if len(peerState.SSHHostKey) > 0 { return peerState.SSHHostKey, true } @@ -262,6 +281,13 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error { return fmt.Errorf("start SSH server: %w", err) } + if v6 := wgAddr.IPv6; v6.IsValid() { + v6Addr := netip.AddrPortFrom(v6, sshserver.InternalSSHPort) + if err := server.AddListener(e.ctx, v6Addr); err != nil { + log.Warnf("failed to add IPv6 SSH listener: %v", err) + } + } + e.sshServer = server if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { @@ -330,6 +356,12 @@ func (e *Engine) cleanupSSHPortRedirection() error { } log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr) + if v6 := e.wgInterface.Address().IPv6; v6.IsValid() { + if err := e.firewall.RemoveInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil { + log.Debugf("failed to remove IPv6 SSH port redirection: %v", err) + } + } + return nil } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index a4ffa3a25..084c642c2 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -188,7 +188,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { case nftypes.TCP, nftypes.UDP, nftypes.SCTP: srcPort = flow.TupleOrig.Proto.SourcePort dstPort = flow.TupleOrig.Proto.DestinationPort - case nftypes.ICMP: + case nftypes.ICMP, nftypes.ICMPv6: icmpType = flow.TupleOrig.Proto.ICMPType icmpCode = flow.TupleOrig.Proto.ICMPCode } @@ -231,8 +231,14 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool { } // fallback if mark rules are not in place - wgnet := c.iface.Address().Network - return wgnet.Contains(srcIP) || wgnet.Contains(dstIP) + addr := c.iface.Address() + if addr.Network.Contains(srcIP) || addr.Network.Contains(dstIP) { + return true + } + if addr.IPv6Net.IsValid() { + return addr.IPv6Net.Contains(srcIP) || addr.IPv6Net.Contains(dstIP) + } + return false } // mapRxPackets maps packet counts to RX based on flow direction @@ -291,17 +297,16 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes } // fallback if marks are not set - wgaddr := c.iface.Address().IP - wgnetwork := c.iface.Address().Network + addr := c.iface.Address() switch { - case wgaddr == srcIP: + case addr.IP == srcIP || (addr.IPv6.IsValid() && addr.IPv6 == srcIP): return nftypes.Egress - case wgaddr == dstIP: + case addr.IP == dstIP || (addr.IPv6.IsValid() && addr.IPv6 == dstIP): return nftypes.Ingress - case wgnetwork.Contains(srcIP): + case addr.Network.Contains(srcIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(srcIP)): // netbird network -> resource network return nftypes.Ingress - case wgnetwork.Contains(dstIP): + case addr.Network.Contains(dstIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(dstIP)): // resource network -> netbird network return nftypes.Egress } diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index a033a2a7c..8f8e68784 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -24,15 +24,17 @@ type Logger struct { cancel context.CancelFunc statusRecorder *peer.Status wgIfaceNet netip.Prefix + wgIfaceNetV6 netip.Prefix dnsCollection atomic.Bool exitNodeCollection atomic.Bool Store types.Store } -func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger { +func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger { return &Logger{ statusRecorder: statusRecorder, wgIfaceNet: wgIfaceIPNet, + wgIfaceNetV6: wgIfaceIPNetV6, Store: store.NewMemoryStore(), } } @@ -88,11 +90,11 @@ func (l *Logger) startReceiver() { var isSrcExitNode bool var isDestExitNode bool - if !l.wgIfaceNet.Contains(event.SourceIP) { + if !l.isOverlayIP(event.SourceIP) { event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) } - if !l.wgIfaceNet.Contains(event.DestIP) { + if !l.isOverlayIP(event.DestIP) { event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) } @@ -136,6 +138,10 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { l.exitNodeCollection.Store(exitNodeCollection) } +func (l *Logger) isOverlayIP(ip netip.Addr) bool { + return l.wgIfaceNet.Contains(ip) || (l.wgIfaceNetV6.IsValid() && l.wgIfaceNetV6.Contains(ip)) +} + func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection if !l.dnsCollection.Load() && event.Protocol == types.UDP && diff --git a/client/internal/netflow/logger/logger_test.go b/client/internal/netflow/logger/logger_test.go index 1144544d8..ad2eedef2 100644 --- a/client/internal/netflow/logger/logger_test.go +++ b/client/internal/netflow/logger/logger_test.go @@ -12,7 +12,7 @@ import ( ) func TestStore(t *testing.T) { - logger := logger.New(nil, netip.Prefix{}) + logger := logger.New(nil, netip.Prefix{}, netip.Prefix{}) logger.Enable() event := types.EventFields{ diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index 7752c97b0..eff083dbf 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -35,11 +35,12 @@ type Manager struct { // NewManager creates a new netflow manager func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { - var prefix netip.Prefix + var prefix, prefixV6 netip.Prefix if iface != nil { prefix = iface.Address().Network + prefixV6 = iface.Address().IPv6Net } - flowLogger := logger.New(statusRecorder, prefix) + flowLogger := logger.New(statusRecorder, prefix, prefixV6) var ct nftypes.ConnTracker if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { @@ -269,7 +270,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent { }, } - if event.Protocol == nftypes.ICMP { + if event.Protocol == nftypes.ICMP || event.Protocol == nftypes.ICMPv6 { protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{ IcmpInfo: &proto.ICMPInfo{ IcmpType: uint32(event.ICMPType), diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index f76146ba3..3f7d0d0ad 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -19,6 +19,7 @@ const ( ICMP = Protocol(1) TCP = Protocol(6) UDP = Protocol(17) + ICMPv6 = Protocol(58) SCTP = Protocol(132) ) @@ -30,6 +31,8 @@ func (p Protocol) String() string { return "TCP" case 17: return "UDP" + case 58: + return "ICMPv6" case 132: return "SCTP" default: diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index 1faa22dc5..c69ea9a6c 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -75,7 +75,7 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar if err != nil { return fmt.Errorf("failed to parse rosenpass address: %w", err) } - peerAddr := fmt.Sprintf("%s:%s", wireGuardIP, strPort) + peerAddr := net.JoinHostPort(wireGuardIP, strPort) if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil { return fmt.Errorf("failed to resolve peer endpoint address: %w", err) } diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index cc47fd2d2..c76f1a212 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "net/netip" "os" "path/filepath" "runtime" @@ -91,7 +92,8 @@ type Manager struct { // PeerSSHInfo represents a peer's SSH configuration information type PeerSSHInfo struct { Hostname string - IP string + IP netip.Addr + IPv6 netip.Addr FQDN string } @@ -211,8 +213,11 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { var hostPatterns []string - if peer.IP != "" { - hostPatterns = append(hostPatterns, peer.IP) + if peer.IP.IsValid() { + hostPatterns = append(hostPatterns, peer.IP.String()) + } + if peer.IPv6.IsValid() { + hostPatterns = append(hostPatterns, peer.IPv6.String()) } if peer.FQDN != "" { hostPatterns = append(hostPatterns, peer.FQDN) diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index dc3ad95b3..bf7b0d1c0 100644 --- a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net/netip" "os" "path/filepath" "runtime" @@ -28,12 +29,12 @@ func TestManager_SetupSSHClientConfig(t *testing.T) { peers := []PeerSSHInfo{ { Hostname: "peer1", - IP: "100.125.1.1", + IP: netip.MustParseAddr("100.125.1.1"), FQDN: "peer1.nb.internal", }, { Hostname: "peer2", - IP: "100.125.1.2", + IP: netip.MustParseAddr("100.125.1.2"), FQDN: "peer2.nb.internal", }, } @@ -101,7 +102,7 @@ func TestManager_PeerLimit(t *testing.T) { for i := 0; i < MaxPeersForSSHConfig+10; i++ { peers = append(peers, PeerSSHInfo{ Hostname: fmt.Sprintf("peer%d", i), - IP: fmt.Sprintf("100.125.1.%d", i%254+1), + IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)), FQDN: fmt.Sprintf("peer%d.nb.internal", i), }) } @@ -136,7 +137,7 @@ func TestManager_ForcedSSHConfig(t *testing.T) { for i := 0; i < MaxPeersForSSHConfig+10; i++ { peers = append(peers, PeerSSHInfo{ Hostname: fmt.Sprintf("peer%d", i), - IP: fmt.Sprintf("100.125.1.%d", i%254+1), + IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)), FQDN: fmt.Sprintf("peer%d.nb.internal", i), }) } diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 4431ae423..0ad8ef127 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -137,10 +137,11 @@ type sessionState struct { } type Server struct { - sshServer *ssh.Server - listener net.Listener - mu sync.RWMutex - hostKeyPEM []byte + sshServer *ssh.Server + listener net.Listener + extraListeners []net.Listener + mu sync.RWMutex + hostKeyPEM []byte // sessions tracks active SSH sessions (shell, command, SFTP). // These are created when a client opens a session channel and requests shell/exec/subsystem. @@ -254,6 +255,35 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return nil } +// AddListener starts serving SSH on an additional address (e.g. IPv6). +// Must be called after Start. +func (s *Server) AddListener(ctx context.Context, addr netip.AddrPort) error { + s.mu.Lock() + srv := s.sshServer + if srv == nil { + s.mu.Unlock() + return errors.New("SSH server is not running") + } + + ln, addrDesc, err := s.createListener(ctx, addr) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("create listener: %w", err) + } + + s.extraListeners = append(s.extraListeners, ln) + s.mu.Unlock() + + log.Infof("SSH server also listening on %s", addrDesc) + + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Errorf("SSH server error on %s: %v", addrDesc, err) + } + }() + return nil +} + func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { if s.netstackNet != nil { ln, err := s.netstackNet.ListenTCPAddrPort(addr) @@ -294,6 +324,13 @@ func (s *Server) Stop() error { log.Debugf("close SSH server: %v", err) } + for _, ln := range s.extraListeners { + if err := ln.Close(); err != nil { + log.Debugf("close extra SSH listener: %v", err) + } + } + s.extraListeners = nil + s.sshServer = nil s.listener = nil @@ -746,11 +783,10 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey { func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { s.mu.RLock() - netbirdNetwork := s.wgAddress.Network - localIP := s.wgAddress.IP + wgAddr := s.wgAddress s.mu.RUnlock() - if !netbirdNetwork.IsValid() || !localIP.IsValid() { + if !wgAddr.Network.IsValid() || !wgAddr.IP.IsValid() { return conn } @@ -766,14 +802,17 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP) return nil } + remoteIP = remoteIP.Unmap() // Block connections from our own IP (prevent local apps from connecting to ourselves) - if remoteIP == localIP { + if remoteIP == wgAddr.IP || wgAddr.IPv6.IsValid() && remoteIP == wgAddr.IPv6 { log.Warnf("SSH connection rejected from own IP %s", remoteIP) return nil } - if !netbirdNetwork.Contains(remoteIP) { + inV4 := wgAddr.Network.Contains(remoteIP) + inV6 := wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(remoteIP) + if !inV4 && !inV6 { log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP) return nil }