Feat linux firewall support (#805)

Update the client's engine to apply firewall rules received from the manager (results of ACL policy).
This commit is contained in:
Givi Khojanashvili
2023-05-29 18:00:18 +04:00
committed by GitHub
parent 2eb9a97fee
commit ba7a39a4fc
51 changed files with 4143 additions and 1013 deletions

90
iface/device_wrapper.go Normal file
View File

@@ -0,0 +1,90 @@
package iface
import (
"net"
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
}
// DeviceWrapper to override Read or Write of packets
type DeviceWrapper struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
}
// newDeviceWrapper constructor function
func newDeviceWrapper(device tun.Device) *DeviceWrapper {
return &DeviceWrapper{
Device: device,
}
}
// Read wraps read method with filtering feature
func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter == nil {
return
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
}
}
return n, nil
}
// Write wraps write method with filtering feature
func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter == nil {
return d.Device.Write(bufs, offset)
}
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) {
filteredBufs = append(filteredBufs, buf)
dropped++
}
}
n, err := d.Device.Write(filteredBufs, offset)
n += dropped
return n, err
}
// SetFiltering sets packet filter to device
func (d *DeviceWrapper) SetFiltering(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()
}

View File

@@ -0,0 +1,216 @@
package iface
import (
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
mocks "github.com/netbirdio/netbird/iface/mocks"
)
func TestDeviceWrapperRead(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
filter := mocks.NewMockPacketFilter(ctrl)
mockBufs := [][]byte{{}}
mockSizes := []int{0}
mockOffset := 0
t.Run("read ICMP", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{192, 168, 0, 1},
DstIP: net.IP{100, 200, 0, 1},
}
icmpLayer := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
Id: 1,
Seq: 1,
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
icmpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
bufs[0] = buffer.Bytes()
sizes[0] = len(bufs[0])
return 1, nil
})
wrapped := newDeviceWrapper(tun)
bufs := [][]byte{{}}
sizes := []int{0}
offset := 0
n, err := wrapped.Read(bufs, sizes, offset)
if err != nil {
t.Errorf("unexpeted error: %v", err)
return
}
if n != 1 {
t.Errorf("expected n=1, got %d", n)
return
}
})
t.Run("write TCP", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{100, 200, 0, 9},
DstIP: net.IP{100, 200, 0, 10},
}
// create TCP layer packet
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(34423),
DstPort: layers.TCPPort(8080),
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
tcpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs[0] = buffer.Bytes()
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
wrapped := newDeviceWrapper(tun)
bufs := [][]byte{buffer.Bytes()}
n, err := wrapped.Write(bufs, 0)
if err != nil {
t.Errorf("unexpeted error: %v", err)
return
}
if n != 1 {
t.Errorf("expected n=1, got %d", n)
return
}
})
t.Run("drop write UDP package", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{100, 200, 0, 11},
DstIP: net.IP{100, 200, 0, 20},
}
// create TCP layer packet
tcpLayer := &layers.UDP{
SrcPort: layers.UDPPort(27278),
DstPort: layers.UDPPort(53),
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
tcpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs = [][]byte{}
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter.EXPECT().DropOutput(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun)
wrapped.filter = filter
bufs := [][]byte{buffer.Bytes()}
n, err := wrapped.Write(bufs, 0)
if err != nil {
t.Errorf("unexpeted error: %v", err)
return
}
if n != 0 {
t.Errorf("expected n=1, got %d", n)
return
}
})
t.Run("drop read UDP package", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolICMPv4,
SrcIP: net.IP{100, 200, 0, 11},
DstIP: net.IP{100, 200, 0, 20},
}
// create TCP layer packet
tcpLayer := &layers.UDP{
SrcPort: layers.UDPPort(19243),
DstPort: layers.UDPPort(1024),
}
buffer := gopacket.NewSerializeBuffer()
err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{},
ipLayer,
tcpLayer,
)
if err != nil {
t.Errorf("serialize packet: %v", err)
return
}
mockBufs := [][]byte{{}}
mockSizes := []int{0}
mockOffset := 0
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
bufs[0] = buffer.Bytes()
sizes[0] = len(bufs[0])
return 1, nil
})
filter.EXPECT().DropInput(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun)
wrapped.filter = filter
bufs := [][]byte{{}}
sizes := []int{0}
offset := 0
n, err := wrapped.Read(bufs, sizes, offset)
if err != nil {
t.Errorf("unexpeted error: %v", err)
return
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
return
}
})
}

View File

@@ -1,6 +1,7 @@
package iface
import (
"fmt"
"net"
"sync"
"time"
@@ -118,3 +119,17 @@ func (w *WGIface) Close() error {
defer w.mu.Unlock()
return w.tun.Close()
}
// SetFiltering sets packet filters for the userspace impelemntation
func (w *WGIface) SetFiltering(filter PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun.wrapper == nil {
return fmt.Errorf("userspace packet filtering not handled on this device")
}
filter.SetNetwork(w.tun.address.Network)
w.tun.wrapper.SetFiltering(filter)
return nil
}

75
iface/mocks/filter.go Normal file
View File

@@ -0,0 +1,75 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// DropInput mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropInput indicates an expected call of DropInput.
func (mr *MockPacketFilterMockRecorder) DropInput(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// DropOutput mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutput indicates an expected call of DropOutput.
func (mr *MockPacketFilterMockRecorder) DropOutput(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

152
iface/mocks/tun.go Normal file
View File

@@ -0,0 +1,152 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: golang.zx2c4.com/wireguard/tun (interfaces: Device)
// Package mocks is a generated GoMock package.
package mocks
import (
os "os"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
tun "golang.zx2c4.com/wireguard/tun"
)
// MockDevice is a mock of Device interface.
type MockDevice struct {
ctrl *gomock.Controller
recorder *MockDeviceMockRecorder
}
// MockDeviceMockRecorder is the mock recorder for MockDevice.
type MockDeviceMockRecorder struct {
mock *MockDevice
}
// NewMockDevice creates a new mock instance.
func NewMockDevice(ctrl *gomock.Controller) *MockDevice {
mock := &MockDevice{ctrl: ctrl}
mock.recorder = &MockDeviceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDevice) EXPECT() *MockDeviceMockRecorder {
return m.recorder
}
// BatchSize mocks base method.
func (m *MockDevice) BatchSize() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchSize")
ret0, _ := ret[0].(int)
return ret0
}
// BatchSize indicates an expected call of BatchSize.
func (mr *MockDeviceMockRecorder) BatchSize() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchSize", reflect.TypeOf((*MockDevice)(nil).BatchSize))
}
// Close mocks base method.
func (m *MockDevice) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockDeviceMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDevice)(nil).Close))
}
// Events mocks base method.
func (m *MockDevice) Events() <-chan tun.Event {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Events")
ret0, _ := ret[0].(<-chan tun.Event)
return ret0
}
// Events indicates an expected call of Events.
func (mr *MockDeviceMockRecorder) Events() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Events", reflect.TypeOf((*MockDevice)(nil).Events))
}
// File mocks base method.
func (m *MockDevice) File() *os.File {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "File")
ret0, _ := ret[0].(*os.File)
return ret0
}
// File indicates an expected call of File.
func (mr *MockDeviceMockRecorder) File() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "File", reflect.TypeOf((*MockDevice)(nil).File))
}
// MTU mocks base method.
func (m *MockDevice) MTU() (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MTU")
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MTU indicates an expected call of MTU.
func (mr *MockDeviceMockRecorder) MTU() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MTU", reflect.TypeOf((*MockDevice)(nil).MTU))
}
// Name mocks base method.
func (m *MockDevice) Name() (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Name indicates an expected call of Name.
func (mr *MockDeviceMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockDevice)(nil).Name))
}
// Read mocks base method.
func (m *MockDevice) Read(arg0 [][]byte, arg1 []int, arg2 int) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0, arg1, arg2)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockDeviceMockRecorder) Read(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockDevice)(nil).Read), arg0, arg1, arg2)
}
// Write mocks base method.
func (m *MockDevice) Write(arg0 [][]byte, arg1 int) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0, arg1)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write.
func (mr *MockDeviceMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockDevice)(nil).Write), arg0, arg1)
}

View File

@@ -22,6 +22,7 @@ type tunDevice struct {
name string
device *device.Device
iceBind *bind.ICEBind
wrapper *DeviceWrapper
}
func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
@@ -49,9 +50,10 @@ func (t *tunDevice) Create() error {
return err
}
t.name = name
t.wrapper = newDeviceWrapper(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(tunDevice, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
//t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -23,6 +23,7 @@ type tunDevice struct {
netInterface NetInterface
iceBind *bind.ICEBind
uapi net.Listener
wrapper *DeviceWrapper
close chan struct{}
}
@@ -90,9 +91,14 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
if err != nil {
return nil, err
}
c.wrapper = newDeviceWrapper(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
tunDev := device.NewDevice(
c.wrapper,
c.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
)
err = tunDev.Up()
if err != nil {
_ = tunIface.Close()

View File

@@ -23,6 +23,7 @@ type tunDevice struct {
iceBind *bind.ICEBind
mtu int
uapi net.Listener
wrapper *DeviceWrapper
close chan struct{}
}
@@ -52,6 +53,8 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
if err != nil {
return nil, err
}
c.wrapper = newDeviceWrapper(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
err = tunDev.Up()