diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 6c667c455..5b48932b4 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -194,17 +194,12 @@ func TestMain(m *testing.M) { } func TestEngine_SSH(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipping TestEngine_SSH") - } - key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) return } - // Generate SSH key for the test sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) if err != nil { t.Fatal(err) diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index d76a70def..1e872f4a7 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -139,11 +139,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return fmt.Errorf("create listener: %w", err) } - if err := s.setupSocketFilter(ln); err != nil { - s.closeListener(ln) - return fmt.Errorf("setup socket filter: %w", err) - } - sshServer, err := s.createSSHServer(ln) if err != nil { s.cleanupOnError(ln) @@ -176,14 +171,6 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L return ln, addr.String(), nil } -// setupSocketFilter attaches socket filter if needed -func (s *Server) setupSocketFilter(ln net.Listener) error { - if s.ifIdx == 0 || ln == nil || s.netstackNet != nil { - return nil - } - return attachSocketFilter(ln, s.ifIdx) -} - // closeListener safely closes a listener func (s *Server) closeListener(ln net.Listener) { if err := ln.Close(); err != nil { @@ -197,9 +184,6 @@ func (s *Server) cleanupOnError(ln net.Listener) { return } - if err := detachSocketFilter(ln); err != nil { - log.Errorf("failed to detach socket filter: %v", err) - } s.closeListener(ln) } @@ -218,13 +202,6 @@ func (s *Server) Stop() error { return nil } - if s.ifIdx > 0 && s.listener != nil { - if err := detachSocketFilter(s.listener); err != nil { - // without detaching the filter, the listener will block on shutdown - return fmt.Errorf("detach socket filter: %w", err) - } - } - if err := s.sshServer.Close(); err != nil && !isShutdownError(err) { return fmt.Errorf("shutdown SSH server: %w", err) } diff --git a/client/ssh/server/socket_filter_linux.go b/client/ssh/server/socket_filter_linux.go deleted file mode 100644 index 730317192..000000000 --- a/client/ssh/server/socket_filter_linux.go +++ /dev/null @@ -1,168 +0,0 @@ -//go:build linux - -package server - -import ( - "fmt" - "net" - "os" - "sync" - "syscall" - "unsafe" - - log "github.com/sirupsen/logrus" - "golang.org/x/net/bpf" - "golang.org/x/sys/unix" -) - -// SockFprog represents a BPF program for socket filtering -type SockFprog struct { - Len uint16 - Filter *unix.SockFilter -} - -// filterInfo stores the file descriptor and filter state for each listener -type filterInfo struct { - fd int - file *os.File -} - -var ( - listenerFilters = make(map[*net.TCPListener]*filterInfo) - filterMutex sync.RWMutex -) - -// attachSocketFilter attaches a BPF socket filter to restrict SSH connections -// to only the specified WireGuard interface index -func attachSocketFilter(listener net.Listener, wgIfIndex int) error { - tcpListener, ok := listener.(*net.TCPListener) - if !ok { - return fmt.Errorf("listener is not a TCP listener") - } - - file, err := tcpListener.File() - if err != nil { - return fmt.Errorf("get listener file descriptor: %w", err) - } - // Don't close the file here - we need it for detaching the filter - - // Set the duplicated FD to non-blocking to match the mode of the - // FD used by the Go runtime's network poller - if err := syscall.SetNonblock(int(file.Fd()), true); err != nil { - file.Close() - return fmt.Errorf("set non-blocking on duplicated FD: %w", err) - } - - // Create BPF program that filters by interface index - prog, err := createInterfaceFilterProgram(uint32(wgIfIndex)) - if err != nil { - file.Close() - return fmt.Errorf("create BPF program: %w", err) - } - - assembled, err := bpf.Assemble(prog) - if err != nil { - file.Close() - return fmt.Errorf("assemble BPF program: %w", err) - } - - // Convert to unix.SockFilter format - sockFilters := make([]unix.SockFilter, len(assembled)) - for i, raw := range assembled { - sockFilters[i] = unix.SockFilter{ - Code: raw.Op, - Jt: raw.Jt, - Jf: raw.Jf, - K: raw.K, - } - } - - // Attach socket filter to the TCP listener - sockFprog := &SockFprog{ - Len: uint16(len(sockFilters)), - Filter: &sockFilters[0], - } - - fd := int(file.Fd()) - _, _, errno := syscall.Syscall6( - unix.SYS_SETSOCKOPT, - uintptr(fd), - uintptr(unix.SOL_SOCKET), - uintptr(unix.SO_ATTACH_FILTER), - uintptr(unsafe.Pointer(sockFprog)), - unsafe.Sizeof(*sockFprog), - 0, - ) - if errno != 0 { - file.Close() - return fmt.Errorf("attach socket filter: %v", errno) - } - - // Store the file descriptor and file for later detach - filterMutex.Lock() - listenerFilters[tcpListener] = &filterInfo{ - fd: fd, - file: file, - } - filterMutex.Unlock() - - log.Debugf("SSH socket filter attached: restricting to interface index %d", wgIfIndex) - return nil -} - -// createInterfaceFilterProgram creates a BPF program that accepts packets -// only from the specified interface index -func createInterfaceFilterProgram(wgIfIndex uint32) ([]bpf.Instruction, error) { - return []bpf.Instruction{ - // Load interface index from socket metadata - // ExtInterfaceIndex is a special BPF extension for interface index - bpf.LoadExtension{Num: bpf.ExtInterfaceIndex}, - - // Compare with WireGuard interface index - bpf.JumpIf{ - Cond: bpf.JumpEqual, - Val: wgIfIndex, - SkipTrue: 1, - }, - - // Reject if not matching (return 0) - bpf.RetConstant{Val: 0}, - - // Accept if matching (return maximum packet size) - bpf.RetConstant{Val: 0xFFFFFFFF}, - }, nil -} - -// detachSocketFilter removes the socket filter from a TCP listener -func detachSocketFilter(listener net.Listener) error { - tcpListener, ok := listener.(*net.TCPListener) - if !ok { - return fmt.Errorf("listener is not a TCP listener") - } - - filterMutex.Lock() - info, exists := listenerFilters[tcpListener] - if exists { - delete(listenerFilters, tcpListener) - } - filterMutex.Unlock() - - if !exists { - log.Debugf("No socket filter attached to detach") - return nil - } - - defer func() { - if closeErr := info.file.Close(); closeErr != nil { - log.Debugf("listener file close error: %v", closeErr) - } - }() - - // Use the same file descriptor that was used for attach - if err := unix.SetsockoptInt(info.fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0); err != nil { - return fmt.Errorf("detach socket filter: %w", err) - } - - log.Debugf("SSH socket filter detached") - return nil -} diff --git a/client/ssh/server/socket_filter_nonlinux.go b/client/ssh/server/socket_filter_nonlinux.go deleted file mode 100644 index a52e15ef2..000000000 --- a/client/ssh/server/socket_filter_nonlinux.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build !linux - -package server - -import ( - "net" -) - -// attachSocketFilter is not supported on non-Linux platforms -func attachSocketFilter(listener net.Listener, wgIfIndex int) error { - // Socket filtering is not available on non-Linux platforms - no-op - return nil -} - -// detachSocketFilter is not supported on non-Linux platforms -func detachSocketFilter(listener net.Listener) error { - // Socket filtering is not available on non-Linux platforms - no-op - return nil -} diff --git a/client/ssh/server/socket_filter_test.go b/client/ssh/server/socket_filter_test.go deleted file mode 100644 index 624aef3a1..000000000 --- a/client/ssh/server/socket_filter_test.go +++ /dev/null @@ -1,160 +0,0 @@ -//go:build linux - -package server - -import ( - "net" - "testing" - - "github.com/stretchr/testify/require" - "golang.org/x/net/bpf" -) - -func TestCreateInterfaceFilterProgram(t *testing.T) { - wgIfIndex := uint32(42) - - prog, err := createInterfaceFilterProgram(wgIfIndex) - require.NoError(t, err, "Should create BPF program without error") - require.NotEmpty(t, prog, "BPF program should not be empty") - - // Verify program structure - require.Len(t, prog, 4, "BPF program should have 4 instructions") - - // Check first instruction - load interface index - loadExt, ok := prog[0].(bpf.LoadExtension) - require.True(t, ok, "First instruction should be LoadExtension") - require.Equal(t, bpf.ExtInterfaceIndex, loadExt.Num, "Should load interface index extension") - - // Check second instruction - compare with target interface - jumpIf, ok := prog[1].(bpf.JumpIf) - require.True(t, ok, "Second instruction should be JumpIf") - require.Equal(t, bpf.JumpEqual, jumpIf.Cond, "Should compare for equality") - require.Equal(t, wgIfIndex, jumpIf.Val, "Should compare with correct interface index") - require.Equal(t, uint8(1), jumpIf.SkipTrue, "Should skip next instruction if match") - - // Check third instruction - reject if not matching - rejectRet, ok := prog[2].(bpf.RetConstant) - require.True(t, ok, "Third instruction should be RetConstant") - require.Equal(t, uint32(0), rejectRet.Val, "Should return 0 to reject packet") - - // Check fourth instruction - accept if matching - acceptRet, ok := prog[3].(bpf.RetConstant) - require.True(t, ok, "Fourth instruction should be RetConstant") - require.Equal(t, uint32(0xFFFFFFFF), acceptRet.Val, "Should return max value to accept packet") -} - -func TestCreateInterfaceFilterProgram_Assembly(t *testing.T) { - wgIfIndex := uint32(10) - - prog, err := createInterfaceFilterProgram(wgIfIndex) - require.NoError(t, err, "Should create BPF program without error") - - // Test that the program can be assembled - assembled, err := bpf.Assemble(prog) - require.NoError(t, err, "BPF program should assemble without error") - require.NotEmpty(t, assembled, "Assembled program should not be empty") - require.True(t, len(assembled) > 0, "Should produce non-empty assembled instructions") -} - -func TestAttachSocketFilter_NonTCPListener(t *testing.T) { - // Create a mock listener that's not a TCP listener - mockListener := &mockFilterListener{} - defer mockListener.Close() - - err := attachSocketFilter(mockListener, 1) - require.Error(t, err, "Should return error for non-TCP listener") - require.Contains(t, err.Error(), "not a TCP listener", "Error should indicate listener type issue") -} - -// mockFilterListener implements net.Listener but is not a TCP listener -type mockFilterListener struct{} - -func (m *mockFilterListener) Accept() (net.Conn, error) { - return nil, net.ErrClosed -} - -func (m *mockFilterListener) Close() error { - return nil -} - -func (m *mockFilterListener) Addr() net.Addr { - addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0") - return addr -} - -func TestAttachSocketFilter_Integration(t *testing.T) { - // Create a test TCP listener - tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") - require.NoError(t, err, "Should resolve TCP address") - - tcpListener, err := net.ListenTCP("tcp", tcpAddr) - require.NoError(t, err, "Should create TCP listener") - defer func() { - if closeErr := tcpListener.Close(); closeErr != nil { - t.Logf("TCP listener close error: %v", closeErr) - } - }() - - // Get a real interface for testing - interfaces, err := net.Interfaces() - require.NoError(t, err, "Should get network interfaces") - require.NotEmpty(t, interfaces, "Should have at least one network interface") - - // Use the first non-loopback interface - var testIfIndex int - for _, iface := range interfaces { - if iface.Flags&net.FlagLoopback == 0 && iface.Index > 0 { - testIfIndex = iface.Index - break - } - } - - if testIfIndex == 0 { - t.Skip("No suitable network interface found for testing") - } - - // Test socket filter attachment - err = attachSocketFilter(tcpListener, testIfIndex) - if err != nil { - // Socket filter attachment may fail in test environments due to permissions - // This is expected and acceptable - t.Logf("Socket filter attachment failed (expected in test environment): %v", err) - return - } - - // If attachment succeeded, test detachment - err = detachSocketFilter(tcpListener) - if err != nil { - // Detachment may fail in test environments due to socket state changes - t.Logf("Socket filter detachment failed (expected in test environment): %v", err) - } -} - -func TestSetSocketFilter_Integration(t *testing.T) { - testKey := []byte(`-----BEGIN OPENSSH PRIVATE KEY----- -b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAFwAAAAdzc2gtcn -NhAAAAAwEAAQAAAQEA2Z3QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbY -rNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCP -fZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X -9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T -1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP -+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UAAAA8g+QKV7Ps -ClezwAAAAAABBAAAAdwdwdF9rZXlfc2VjcmV0AAAAAQAAAQEA2Z3QY0EfAFU+wU1M7FH+ -6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV -7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU -1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5Q -Z4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAF -U+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Y -x8gKQBz5vBV7V3X9UAAAA8g+QKV7PsClezwAAA= ------END OPENSSH PRIVATE KEY-----`) - - server := New(testKey) - require.NotNil(t, server, "Should create SSH server") - - // Test SetSocketFilter method - testIfIndex := 42 - server.SetSocketFilter(testIfIndex) - - // Verify the socket filter configuration was stored - require.Equal(t, testIfIndex, server.ifIdx, "Should store correct interface index") -}