Merge remote-tracking branch 'origin/main' into improve-usp-fw

# Conflicts:
#	client/firewall/uspfilter/conntrack/common.go
#	client/firewall/uspfilter/filter.go
#	client/firewall/uspfilter/forwarder/icmp.go
#	client/firewall/uspfilter/forwarder/tcp.go
#	client/firewall/uspfilter/nat.go
This commit is contained in:
Viktor Liu
2026-05-07 12:29:30 +02:00
420 changed files with 25335 additions and 10765 deletions

View File

@@ -143,10 +143,11 @@ type sessionState struct {
}
type Server struct {
sshServer *ssh.Server
listener net.Listener
mu sync.RWMutex
hostKeyPEM []byte
sshServer *ssh.Server
listener net.Listener
extraListeners []net.Listener
mu sync.RWMutex
hostKeyPEM []byte
// sessions tracks active SSH sessions (shell, command, SFTP).
// These are created when a client opens a session channel and requests shell/exec/subsystem.
@@ -260,6 +261,35 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return nil
}
// AddListener starts serving SSH on an additional address (e.g. IPv6).
// Must be called after Start.
func (s *Server) AddListener(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
srv := s.sshServer
if srv == nil {
s.mu.Unlock()
return errors.New("SSH server is not running")
}
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {
s.mu.Unlock()
return fmt.Errorf("create listener: %w", err)
}
s.extraListeners = append(s.extraListeners, ln)
s.mu.Unlock()
log.Infof("SSH server also listening on %s", addrDesc)
go func() {
if err := srv.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Errorf("SSH server error on %s: %v", addrDesc, err)
}
}()
return nil
}
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
@@ -297,6 +327,8 @@ func (s *Server) Stop() error {
}
s.sshServer = nil
s.listener = nil
extraListeners := s.extraListeners
s.extraListeners = nil
s.mu.Unlock()
// Close outside the lock: session handlers need s.mu for unregisterSession.
@@ -304,6 +336,12 @@ func (s *Server) Stop() error {
log.Debugf("close SSH server: %v", err)
}
for _, ln := range extraListeners {
if err := ln.Close(); err != nil {
log.Debugf("close extra SSH listener: %v", err)
}
}
s.mu.Lock()
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
@@ -755,11 +793,10 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
localIP := s.wgAddress.IP
wgAddr := s.wgAddress
s.mu.RUnlock()
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
if !wgAddr.Network.IsValid() || !wgAddr.IP.IsValid() {
return conn
}
@@ -775,14 +812,17 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
return nil
}
remoteIP = remoteIP.Unmap()
// Block connections from our own IP (prevent local apps from connecting to ourselves)
if remoteIP == localIP {
if remoteIP == wgAddr.IP || wgAddr.IPv6.IsValid() && remoteIP == wgAddr.IPv6 {
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
return nil
}
if !netbirdNetwork.Contains(remoteIP) {
inV4 := wgAddr.Network.Contains(remoteIP)
inV6 := wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(remoteIP)
if !inV4 && !inV6 {
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}
@@ -882,20 +922,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))))
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))
forwardAddr := "-L " + hostPort
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
logger.Infof("local port forwarding: %s", hostPort)
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
}