Remove socketfilter temporarily

This commit is contained in:
Viktor Liu
2025-07-02 22:00:10 +02:00
parent 1fdde66c31
commit 612de2c784
5 changed files with 0 additions and 375 deletions

View File

@@ -139,11 +139,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return fmt.Errorf("create listener: %w", err)
}
if err := s.setupSocketFilter(ln); err != nil {
s.closeListener(ln)
return fmt.Errorf("setup socket filter: %w", err)
}
sshServer, err := s.createSSHServer(ln)
if err != nil {
s.cleanupOnError(ln)
@@ -176,14 +171,6 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
return ln, addr.String(), nil
}
// setupSocketFilter attaches socket filter if needed
func (s *Server) setupSocketFilter(ln net.Listener) error {
if s.ifIdx == 0 || ln == nil || s.netstackNet != nil {
return nil
}
return attachSocketFilter(ln, s.ifIdx)
}
// closeListener safely closes a listener
func (s *Server) closeListener(ln net.Listener) {
if err := ln.Close(); err != nil {
@@ -197,9 +184,6 @@ func (s *Server) cleanupOnError(ln net.Listener) {
return
}
if err := detachSocketFilter(ln); err != nil {
log.Errorf("failed to detach socket filter: %v", err)
}
s.closeListener(ln)
}
@@ -218,13 +202,6 @@ func (s *Server) Stop() error {
return nil
}
if s.ifIdx > 0 && s.listener != nil {
if err := detachSocketFilter(s.listener); err != nil {
// without detaching the filter, the listener will block on shutdown
return fmt.Errorf("detach socket filter: %w", err)
}
}
if err := s.sshServer.Close(); err != nil && !isShutdownError(err) {
return fmt.Errorf("shutdown SSH server: %w", err)
}

View File

@@ -1,168 +0,0 @@
//go:build linux
package server
import (
"fmt"
"net"
"os"
"sync"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/bpf"
"golang.org/x/sys/unix"
)
// SockFprog represents a BPF program for socket filtering
type SockFprog struct {
Len uint16
Filter *unix.SockFilter
}
// filterInfo stores the file descriptor and filter state for each listener
type filterInfo struct {
fd int
file *os.File
}
var (
listenerFilters = make(map[*net.TCPListener]*filterInfo)
filterMutex sync.RWMutex
)
// attachSocketFilter attaches a BPF socket filter to restrict SSH connections
// to only the specified WireGuard interface index
func attachSocketFilter(listener net.Listener, wgIfIndex int) error {
tcpListener, ok := listener.(*net.TCPListener)
if !ok {
return fmt.Errorf("listener is not a TCP listener")
}
file, err := tcpListener.File()
if err != nil {
return fmt.Errorf("get listener file descriptor: %w", err)
}
// Don't close the file here - we need it for detaching the filter
// Set the duplicated FD to non-blocking to match the mode of the
// FD used by the Go runtime's network poller
if err := syscall.SetNonblock(int(file.Fd()), true); err != nil {
file.Close()
return fmt.Errorf("set non-blocking on duplicated FD: %w", err)
}
// Create BPF program that filters by interface index
prog, err := createInterfaceFilterProgram(uint32(wgIfIndex))
if err != nil {
file.Close()
return fmt.Errorf("create BPF program: %w", err)
}
assembled, err := bpf.Assemble(prog)
if err != nil {
file.Close()
return fmt.Errorf("assemble BPF program: %w", err)
}
// Convert to unix.SockFilter format
sockFilters := make([]unix.SockFilter, len(assembled))
for i, raw := range assembled {
sockFilters[i] = unix.SockFilter{
Code: raw.Op,
Jt: raw.Jt,
Jf: raw.Jf,
K: raw.K,
}
}
// Attach socket filter to the TCP listener
sockFprog := &SockFprog{
Len: uint16(len(sockFilters)),
Filter: &sockFilters[0],
}
fd := int(file.Fd())
_, _, errno := syscall.Syscall6(
unix.SYS_SETSOCKOPT,
uintptr(fd),
uintptr(unix.SOL_SOCKET),
uintptr(unix.SO_ATTACH_FILTER),
uintptr(unsafe.Pointer(sockFprog)),
unsafe.Sizeof(*sockFprog),
0,
)
if errno != 0 {
file.Close()
return fmt.Errorf("attach socket filter: %v", errno)
}
// Store the file descriptor and file for later detach
filterMutex.Lock()
listenerFilters[tcpListener] = &filterInfo{
fd: fd,
file: file,
}
filterMutex.Unlock()
log.Debugf("SSH socket filter attached: restricting to interface index %d", wgIfIndex)
return nil
}
// createInterfaceFilterProgram creates a BPF program that accepts packets
// only from the specified interface index
func createInterfaceFilterProgram(wgIfIndex uint32) ([]bpf.Instruction, error) {
return []bpf.Instruction{
// Load interface index from socket metadata
// ExtInterfaceIndex is a special BPF extension for interface index
bpf.LoadExtension{Num: bpf.ExtInterfaceIndex},
// Compare with WireGuard interface index
bpf.JumpIf{
Cond: bpf.JumpEqual,
Val: wgIfIndex,
SkipTrue: 1,
},
// Reject if not matching (return 0)
bpf.RetConstant{Val: 0},
// Accept if matching (return maximum packet size)
bpf.RetConstant{Val: 0xFFFFFFFF},
}, nil
}
// detachSocketFilter removes the socket filter from a TCP listener
func detachSocketFilter(listener net.Listener) error {
tcpListener, ok := listener.(*net.TCPListener)
if !ok {
return fmt.Errorf("listener is not a TCP listener")
}
filterMutex.Lock()
info, exists := listenerFilters[tcpListener]
if exists {
delete(listenerFilters, tcpListener)
}
filterMutex.Unlock()
if !exists {
log.Debugf("No socket filter attached to detach")
return nil
}
defer func() {
if closeErr := info.file.Close(); closeErr != nil {
log.Debugf("listener file close error: %v", closeErr)
}
}()
// Use the same file descriptor that was used for attach
if err := unix.SetsockoptInt(info.fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0); err != nil {
return fmt.Errorf("detach socket filter: %w", err)
}
log.Debugf("SSH socket filter detached")
return nil
}

View File

@@ -1,19 +0,0 @@
//go:build !linux
package server
import (
"net"
)
// attachSocketFilter is not supported on non-Linux platforms
func attachSocketFilter(listener net.Listener, wgIfIndex int) error {
// Socket filtering is not available on non-Linux platforms - no-op
return nil
}
// detachSocketFilter is not supported on non-Linux platforms
func detachSocketFilter(listener net.Listener) error {
// Socket filtering is not available on non-Linux platforms - no-op
return nil
}

View File

@@ -1,160 +0,0 @@
//go:build linux
package server
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/net/bpf"
)
func TestCreateInterfaceFilterProgram(t *testing.T) {
wgIfIndex := uint32(42)
prog, err := createInterfaceFilterProgram(wgIfIndex)
require.NoError(t, err, "Should create BPF program without error")
require.NotEmpty(t, prog, "BPF program should not be empty")
// Verify program structure
require.Len(t, prog, 4, "BPF program should have 4 instructions")
// Check first instruction - load interface index
loadExt, ok := prog[0].(bpf.LoadExtension)
require.True(t, ok, "First instruction should be LoadExtension")
require.Equal(t, bpf.ExtInterfaceIndex, loadExt.Num, "Should load interface index extension")
// Check second instruction - compare with target interface
jumpIf, ok := prog[1].(bpf.JumpIf)
require.True(t, ok, "Second instruction should be JumpIf")
require.Equal(t, bpf.JumpEqual, jumpIf.Cond, "Should compare for equality")
require.Equal(t, wgIfIndex, jumpIf.Val, "Should compare with correct interface index")
require.Equal(t, uint8(1), jumpIf.SkipTrue, "Should skip next instruction if match")
// Check third instruction - reject if not matching
rejectRet, ok := prog[2].(bpf.RetConstant)
require.True(t, ok, "Third instruction should be RetConstant")
require.Equal(t, uint32(0), rejectRet.Val, "Should return 0 to reject packet")
// Check fourth instruction - accept if matching
acceptRet, ok := prog[3].(bpf.RetConstant)
require.True(t, ok, "Fourth instruction should be RetConstant")
require.Equal(t, uint32(0xFFFFFFFF), acceptRet.Val, "Should return max value to accept packet")
}
func TestCreateInterfaceFilterProgram_Assembly(t *testing.T) {
wgIfIndex := uint32(10)
prog, err := createInterfaceFilterProgram(wgIfIndex)
require.NoError(t, err, "Should create BPF program without error")
// Test that the program can be assembled
assembled, err := bpf.Assemble(prog)
require.NoError(t, err, "BPF program should assemble without error")
require.NotEmpty(t, assembled, "Assembled program should not be empty")
require.True(t, len(assembled) > 0, "Should produce non-empty assembled instructions")
}
func TestAttachSocketFilter_NonTCPListener(t *testing.T) {
// Create a mock listener that's not a TCP listener
mockListener := &mockFilterListener{}
defer mockListener.Close()
err := attachSocketFilter(mockListener, 1)
require.Error(t, err, "Should return error for non-TCP listener")
require.Contains(t, err.Error(), "not a TCP listener", "Error should indicate listener type issue")
}
// mockFilterListener implements net.Listener but is not a TCP listener
type mockFilterListener struct{}
func (m *mockFilterListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (m *mockFilterListener) Close() error {
return nil
}
func (m *mockFilterListener) Addr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
return addr
}
func TestAttachSocketFilter_Integration(t *testing.T) {
// Create a test TCP listener
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
require.NoError(t, err, "Should resolve TCP address")
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
require.NoError(t, err, "Should create TCP listener")
defer func() {
if closeErr := tcpListener.Close(); closeErr != nil {
t.Logf("TCP listener close error: %v", closeErr)
}
}()
// Get a real interface for testing
interfaces, err := net.Interfaces()
require.NoError(t, err, "Should get network interfaces")
require.NotEmpty(t, interfaces, "Should have at least one network interface")
// Use the first non-loopback interface
var testIfIndex int
for _, iface := range interfaces {
if iface.Flags&net.FlagLoopback == 0 && iface.Index > 0 {
testIfIndex = iface.Index
break
}
}
if testIfIndex == 0 {
t.Skip("No suitable network interface found for testing")
}
// Test socket filter attachment
err = attachSocketFilter(tcpListener, testIfIndex)
if err != nil {
// Socket filter attachment may fail in test environments due to permissions
// This is expected and acceptable
t.Logf("Socket filter attachment failed (expected in test environment): %v", err)
return
}
// If attachment succeeded, test detachment
err = detachSocketFilter(tcpListener)
if err != nil {
// Detachment may fail in test environments due to socket state changes
t.Logf("Socket filter detachment failed (expected in test environment): %v", err)
}
}
func TestSetSocketFilter_Integration(t *testing.T) {
testKey := []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEA2Z3QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbY
rNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCP
fZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X
9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T
1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP
+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UAAAA8g+QKV7Ps
ClezwAAAAAABBAAAAdwdwdF9rZXlfc2VjcmV0AAAAAQAAAQEA2Z3QY0EfAFU+wU1M7FH+
6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV
7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU
1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5Q
Z4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAF
U+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Y
x8gKQBz5vBV7V3X9UAAAA8g+QKV7PsClezwAAA=
-----END OPENSSH PRIVATE KEY-----`)
server := New(testKey)
require.NotNil(t, server, "Should create SSH server")
// Test SetSocketFilter method
testIfIndex := 42
server.SetSocketFilter(testIfIndex)
// Verify the socket filter configuration was stored
require.Equal(t, testIfIndex, server.ifIdx, "Should store correct interface index")
}