mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 17:08:53 +00:00
Remove socketfilter temporarily
This commit is contained in:
@@ -139,11 +139,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
return fmt.Errorf("create listener: %w", err)
|
||||
}
|
||||
|
||||
if err := s.setupSocketFilter(ln); err != nil {
|
||||
s.closeListener(ln)
|
||||
return fmt.Errorf("setup socket filter: %w", err)
|
||||
}
|
||||
|
||||
sshServer, err := s.createSSHServer(ln)
|
||||
if err != nil {
|
||||
s.cleanupOnError(ln)
|
||||
@@ -176,14 +171,6 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
|
||||
return ln, addr.String(), nil
|
||||
}
|
||||
|
||||
// setupSocketFilter attaches socket filter if needed
|
||||
func (s *Server) setupSocketFilter(ln net.Listener) error {
|
||||
if s.ifIdx == 0 || ln == nil || s.netstackNet != nil {
|
||||
return nil
|
||||
}
|
||||
return attachSocketFilter(ln, s.ifIdx)
|
||||
}
|
||||
|
||||
// closeListener safely closes a listener
|
||||
func (s *Server) closeListener(ln net.Listener) {
|
||||
if err := ln.Close(); err != nil {
|
||||
@@ -197,9 +184,6 @@ func (s *Server) cleanupOnError(ln net.Listener) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := detachSocketFilter(ln); err != nil {
|
||||
log.Errorf("failed to detach socket filter: %v", err)
|
||||
}
|
||||
s.closeListener(ln)
|
||||
}
|
||||
|
||||
@@ -218,13 +202,6 @@ func (s *Server) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.ifIdx > 0 && s.listener != nil {
|
||||
if err := detachSocketFilter(s.listener); err != nil {
|
||||
// without detaching the filter, the listener will block on shutdown
|
||||
return fmt.Errorf("detach socket filter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.sshServer.Close(); err != nil && !isShutdownError(err) {
|
||||
return fmt.Errorf("shutdown SSH server: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/bpf"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// SockFprog represents a BPF program for socket filtering
|
||||
type SockFprog struct {
|
||||
Len uint16
|
||||
Filter *unix.SockFilter
|
||||
}
|
||||
|
||||
// filterInfo stores the file descriptor and filter state for each listener
|
||||
type filterInfo struct {
|
||||
fd int
|
||||
file *os.File
|
||||
}
|
||||
|
||||
var (
|
||||
listenerFilters = make(map[*net.TCPListener]*filterInfo)
|
||||
filterMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// attachSocketFilter attaches a BPF socket filter to restrict SSH connections
|
||||
// to only the specified WireGuard interface index
|
||||
func attachSocketFilter(listener net.Listener, wgIfIndex int) error {
|
||||
tcpListener, ok := listener.(*net.TCPListener)
|
||||
if !ok {
|
||||
return fmt.Errorf("listener is not a TCP listener")
|
||||
}
|
||||
|
||||
file, err := tcpListener.File()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get listener file descriptor: %w", err)
|
||||
}
|
||||
// Don't close the file here - we need it for detaching the filter
|
||||
|
||||
// Set the duplicated FD to non-blocking to match the mode of the
|
||||
// FD used by the Go runtime's network poller
|
||||
if err := syscall.SetNonblock(int(file.Fd()), true); err != nil {
|
||||
file.Close()
|
||||
return fmt.Errorf("set non-blocking on duplicated FD: %w", err)
|
||||
}
|
||||
|
||||
// Create BPF program that filters by interface index
|
||||
prog, err := createInterfaceFilterProgram(uint32(wgIfIndex))
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return fmt.Errorf("create BPF program: %w", err)
|
||||
}
|
||||
|
||||
assembled, err := bpf.Assemble(prog)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return fmt.Errorf("assemble BPF program: %w", err)
|
||||
}
|
||||
|
||||
// Convert to unix.SockFilter format
|
||||
sockFilters := make([]unix.SockFilter, len(assembled))
|
||||
for i, raw := range assembled {
|
||||
sockFilters[i] = unix.SockFilter{
|
||||
Code: raw.Op,
|
||||
Jt: raw.Jt,
|
||||
Jf: raw.Jf,
|
||||
K: raw.K,
|
||||
}
|
||||
}
|
||||
|
||||
// Attach socket filter to the TCP listener
|
||||
sockFprog := &SockFprog{
|
||||
Len: uint16(len(sockFilters)),
|
||||
Filter: &sockFilters[0],
|
||||
}
|
||||
|
||||
fd := int(file.Fd())
|
||||
_, _, errno := syscall.Syscall6(
|
||||
unix.SYS_SETSOCKOPT,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SOL_SOCKET),
|
||||
uintptr(unix.SO_ATTACH_FILTER),
|
||||
uintptr(unsafe.Pointer(sockFprog)),
|
||||
unsafe.Sizeof(*sockFprog),
|
||||
0,
|
||||
)
|
||||
if errno != 0 {
|
||||
file.Close()
|
||||
return fmt.Errorf("attach socket filter: %v", errno)
|
||||
}
|
||||
|
||||
// Store the file descriptor and file for later detach
|
||||
filterMutex.Lock()
|
||||
listenerFilters[tcpListener] = &filterInfo{
|
||||
fd: fd,
|
||||
file: file,
|
||||
}
|
||||
filterMutex.Unlock()
|
||||
|
||||
log.Debugf("SSH socket filter attached: restricting to interface index %d", wgIfIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// createInterfaceFilterProgram creates a BPF program that accepts packets
|
||||
// only from the specified interface index
|
||||
func createInterfaceFilterProgram(wgIfIndex uint32) ([]bpf.Instruction, error) {
|
||||
return []bpf.Instruction{
|
||||
// Load interface index from socket metadata
|
||||
// ExtInterfaceIndex is a special BPF extension for interface index
|
||||
bpf.LoadExtension{Num: bpf.ExtInterfaceIndex},
|
||||
|
||||
// Compare with WireGuard interface index
|
||||
bpf.JumpIf{
|
||||
Cond: bpf.JumpEqual,
|
||||
Val: wgIfIndex,
|
||||
SkipTrue: 1,
|
||||
},
|
||||
|
||||
// Reject if not matching (return 0)
|
||||
bpf.RetConstant{Val: 0},
|
||||
|
||||
// Accept if matching (return maximum packet size)
|
||||
bpf.RetConstant{Val: 0xFFFFFFFF},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// detachSocketFilter removes the socket filter from a TCP listener
|
||||
func detachSocketFilter(listener net.Listener) error {
|
||||
tcpListener, ok := listener.(*net.TCPListener)
|
||||
if !ok {
|
||||
return fmt.Errorf("listener is not a TCP listener")
|
||||
}
|
||||
|
||||
filterMutex.Lock()
|
||||
info, exists := listenerFilters[tcpListener]
|
||||
if exists {
|
||||
delete(listenerFilters, tcpListener)
|
||||
}
|
||||
filterMutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
log.Debugf("No socket filter attached to detach")
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if closeErr := info.file.Close(); closeErr != nil {
|
||||
log.Debugf("listener file close error: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Use the same file descriptor that was used for attach
|
||||
if err := unix.SetsockoptInt(info.fd, unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0); err != nil {
|
||||
return fmt.Errorf("detach socket filter: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("SSH socket filter detached")
|
||||
return nil
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// attachSocketFilter is not supported on non-Linux platforms
|
||||
func attachSocketFilter(listener net.Listener, wgIfIndex int) error {
|
||||
// Socket filtering is not available on non-Linux platforms - no-op
|
||||
return nil
|
||||
}
|
||||
|
||||
// detachSocketFilter is not supported on non-Linux platforms
|
||||
func detachSocketFilter(listener net.Listener) error {
|
||||
// Socket filtering is not available on non-Linux platforms - no-op
|
||||
return nil
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/bpf"
|
||||
)
|
||||
|
||||
func TestCreateInterfaceFilterProgram(t *testing.T) {
|
||||
wgIfIndex := uint32(42)
|
||||
|
||||
prog, err := createInterfaceFilterProgram(wgIfIndex)
|
||||
require.NoError(t, err, "Should create BPF program without error")
|
||||
require.NotEmpty(t, prog, "BPF program should not be empty")
|
||||
|
||||
// Verify program structure
|
||||
require.Len(t, prog, 4, "BPF program should have 4 instructions")
|
||||
|
||||
// Check first instruction - load interface index
|
||||
loadExt, ok := prog[0].(bpf.LoadExtension)
|
||||
require.True(t, ok, "First instruction should be LoadExtension")
|
||||
require.Equal(t, bpf.ExtInterfaceIndex, loadExt.Num, "Should load interface index extension")
|
||||
|
||||
// Check second instruction - compare with target interface
|
||||
jumpIf, ok := prog[1].(bpf.JumpIf)
|
||||
require.True(t, ok, "Second instruction should be JumpIf")
|
||||
require.Equal(t, bpf.JumpEqual, jumpIf.Cond, "Should compare for equality")
|
||||
require.Equal(t, wgIfIndex, jumpIf.Val, "Should compare with correct interface index")
|
||||
require.Equal(t, uint8(1), jumpIf.SkipTrue, "Should skip next instruction if match")
|
||||
|
||||
// Check third instruction - reject if not matching
|
||||
rejectRet, ok := prog[2].(bpf.RetConstant)
|
||||
require.True(t, ok, "Third instruction should be RetConstant")
|
||||
require.Equal(t, uint32(0), rejectRet.Val, "Should return 0 to reject packet")
|
||||
|
||||
// Check fourth instruction - accept if matching
|
||||
acceptRet, ok := prog[3].(bpf.RetConstant)
|
||||
require.True(t, ok, "Fourth instruction should be RetConstant")
|
||||
require.Equal(t, uint32(0xFFFFFFFF), acceptRet.Val, "Should return max value to accept packet")
|
||||
}
|
||||
|
||||
func TestCreateInterfaceFilterProgram_Assembly(t *testing.T) {
|
||||
wgIfIndex := uint32(10)
|
||||
|
||||
prog, err := createInterfaceFilterProgram(wgIfIndex)
|
||||
require.NoError(t, err, "Should create BPF program without error")
|
||||
|
||||
// Test that the program can be assembled
|
||||
assembled, err := bpf.Assemble(prog)
|
||||
require.NoError(t, err, "BPF program should assemble without error")
|
||||
require.NotEmpty(t, assembled, "Assembled program should not be empty")
|
||||
require.True(t, len(assembled) > 0, "Should produce non-empty assembled instructions")
|
||||
}
|
||||
|
||||
func TestAttachSocketFilter_NonTCPListener(t *testing.T) {
|
||||
// Create a mock listener that's not a TCP listener
|
||||
mockListener := &mockFilterListener{}
|
||||
defer mockListener.Close()
|
||||
|
||||
err := attachSocketFilter(mockListener, 1)
|
||||
require.Error(t, err, "Should return error for non-TCP listener")
|
||||
require.Contains(t, err.Error(), "not a TCP listener", "Error should indicate listener type issue")
|
||||
}
|
||||
|
||||
// mockFilterListener implements net.Listener but is not a TCP listener
|
||||
type mockFilterListener struct{}
|
||||
|
||||
func (m *mockFilterListener) Accept() (net.Conn, error) {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (m *mockFilterListener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockFilterListener) Addr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
||||
return addr
|
||||
}
|
||||
|
||||
func TestAttachSocketFilter_Integration(t *testing.T) {
|
||||
// Create a test TCP listener
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "Should resolve TCP address")
|
||||
|
||||
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
|
||||
require.NoError(t, err, "Should create TCP listener")
|
||||
defer func() {
|
||||
if closeErr := tcpListener.Close(); closeErr != nil {
|
||||
t.Logf("TCP listener close error: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Get a real interface for testing
|
||||
interfaces, err := net.Interfaces()
|
||||
require.NoError(t, err, "Should get network interfaces")
|
||||
require.NotEmpty(t, interfaces, "Should have at least one network interface")
|
||||
|
||||
// Use the first non-loopback interface
|
||||
var testIfIndex int
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagLoopback == 0 && iface.Index > 0 {
|
||||
testIfIndex = iface.Index
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if testIfIndex == 0 {
|
||||
t.Skip("No suitable network interface found for testing")
|
||||
}
|
||||
|
||||
// Test socket filter attachment
|
||||
err = attachSocketFilter(tcpListener, testIfIndex)
|
||||
if err != nil {
|
||||
// Socket filter attachment may fail in test environments due to permissions
|
||||
// This is expected and acceptable
|
||||
t.Logf("Socket filter attachment failed (expected in test environment): %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If attachment succeeded, test detachment
|
||||
err = detachSocketFilter(tcpListener)
|
||||
if err != nil {
|
||||
// Detachment may fail in test environments due to socket state changes
|
||||
t.Logf("Socket filter detachment failed (expected in test environment): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetSocketFilter_Integration(t *testing.T) {
|
||||
testKey := []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
|
||||
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAFwAAAAdzc2gtcn
|
||||
NhAAAAAwEAAQAAAQEA2Z3QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbY
|
||||
rNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCP
|
||||
fZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X
|
||||
9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T
|
||||
1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP
|
||||
+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UAAAA8g+QKV7Ps
|
||||
ClezwAAAAAABBAAAAdwdwdF9rZXlfc2VjcmV0AAAAAQAAAQEA2Z3QY0EfAFU+wU1M7FH+
|
||||
6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV
|
||||
7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU
|
||||
1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAFU+wU1M7FH+6QCPfZhL1H5ZbG5Q
|
||||
Z4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Yx8gKQBz5vBV7V3X9UY1QY0EfAF
|
||||
U+wU1M7FH+6QCPfZhL1H5ZbG5QZ4oP+H8Y7QJYbYrNYmY+x2G5nU1J5T1x6QaKv8Y5Y
|
||||
x8gKQBz5vBV7V3X9UAAAA8g+QKV7PsClezwAAA=
|
||||
-----END OPENSSH PRIVATE KEY-----`)
|
||||
|
||||
server := New(testKey)
|
||||
require.NotNil(t, server, "Should create SSH server")
|
||||
|
||||
// Test SetSocketFilter method
|
||||
testIfIndex := 42
|
||||
server.SetSocketFilter(testIfIndex)
|
||||
|
||||
// Verify the socket filter configuration was stored
|
||||
require.Equal(t, testIfIndex, server.ifIdx, "Should store correct interface index")
|
||||
}
|
||||
Reference in New Issue
Block a user