From cb12e2da21ff6c5e1b5b9e03c6b34d7c085c97ff Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 4 May 2023 12:28:32 +0200 Subject: [PATCH] Correct sharedsock BPF fields (#835) --- client/internal/engine.go | 2 +- sharedsock/filter_linux.go | 30 +++++++++++++++--------------- sharedsock/filter_nolinux.go | 4 ++-- sharedsock/sock_linux.go | 3 ++- sharedsock/sock_linux_test.go | 8 ++++---- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0f0750c65..3011495db 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -210,7 +210,7 @@ func (e *Engine) Start() error { e.udpMux = udpMux log.Infof("using userspace bind mode %s", udpMux.LocalAddr().String()) } else { - rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewSTUNFilter()) + rawSock, err := sharedsock.Listen(e.config.WgPort, sharedsock.NewIncomingSTUNFilter()) if err != nil { return err } diff --git a/sharedsock/filter_linux.go b/sharedsock/filter_linux.go index a9903f03f..2dd3eaded 100644 --- a/sharedsock/filter_linux.go +++ b/sharedsock/filter_linux.go @@ -2,44 +2,44 @@ package sharedsock import "golang.org/x/net/bpf" -// STUNFilter implements BPFFilter by filtering on STUN packets. +// IncomingSTUNFilter implements BPFFilter and filters out anything but incoming STUN packets to a specified destination port. // Other packets (non STUN) will be forwarded to the process that own the port (e.g., WireGuard). -type STUNFilter struct { +type IncomingSTUNFilter struct { } -// NewSTUNFilter creates an instance of a STUNFilter -func NewSTUNFilter() BPFFilter { - return &STUNFilter{} +// NewIncomingSTUNFilter creates an instance of a IncomingSTUNFilter +func NewIncomingSTUNFilter() BPFFilter { + return &IncomingSTUNFilter{} } // GetInstructions returns raw BPF instructions for ipv4 and ipv6 that filter out anything but STUN packets -func (sf STUNFilter) GetInstructions(port uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) { - raw4, err = rawInstructions(22, 32, port) +func (filter *IncomingSTUNFilter) GetInstructions(dstPort uint32) (raw4 []bpf.RawInstruction, raw6 []bpf.RawInstruction, err error) { + raw4, err = rawInstructions(22, 32, dstPort) if err != nil { return nil, nil, err } - raw6, err = rawInstructions(2, 12, port) + raw6, err = rawInstructions(2, 12, dstPort) if err != nil { return nil, nil, err } return raw4, raw6, nil } -func rawInstructions(portOff, cookieOff, port uint32) ([]bpf.RawInstruction, error) { +func rawInstructions(dstPortOff, cookieOff, dstPort uint32) ([]bpf.RawInstruction, error) { // UDP raw socket for ipv4 receives the rcvdPacket with IP headers // UDP raw socket for ipv6 receives the rcvdPacket with UDP headers instructions := []bpf.Instruction{ - // Load the source port from the UDP header (offset 22 for ipv4 and 2 for ipv6) - bpf.LoadAbsolute{Off: portOff, Size: 2}, - // Check if the source port is equal to the specified `port`. If not, skip the next 3 instructions. - bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: port, SkipTrue: 3}, + // Load the destination port from the UDP header (offset 22 for ipv4 and 2 for ipv6) + bpf.LoadAbsolute{Off: dstPortOff, Size: 2}, + // Check if the destination port is equal to the specified `dstPort`. If not, skip the next 3 instructions. + bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: dstPort, SkipTrue: 3}, // Load the 4-byte value (magic cookie) from the UDP payload (offset 32 for ipv4 and 12 for ipv6) bpf.LoadAbsolute{Off: cookieOff, Size: 4}, // Check if the loaded value is equal to the `magicCookie`. If not, skip the next instruction. bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: magicCookie, SkipTrue: 1}, - // If both the port and the magic cookie match, return a positive value (0xffffffff) + // If both the dstPort and the magic cookie match, return a positive value (0xffffffff) bpf.RetConstant{Val: 0xffffffff}, - // If either the port or the magic cookie doesn't match, return 0 + // If either the dstPort or the magic cookie doesn't match, return 0 bpf.RetConstant{Val: 0}, } diff --git a/sharedsock/filter_nolinux.go b/sharedsock/filter_nolinux.go index 3ae648ade..bf1f22d88 100644 --- a/sharedsock/filter_nolinux.go +++ b/sharedsock/filter_nolinux.go @@ -2,7 +2,7 @@ package sharedsock -// NewSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux -func NewSTUNFilter() BPFFilter { +// NewIncomingSTUNFilter is a noop method just because we do not support BPF filters on other platforms than Linux +func NewIncomingSTUNFilter() BPFFilter { return nil } diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index f6501928a..5d2b5a528 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -27,7 +27,8 @@ import ( var ErrSharedSockStopped = fmt.Errorf("shared socked stopped") // SharedSocket is a net.PacketConn that initiates two raw sockets (ipv4 and ipv6) and listens to UDP packets filtered -// by BPF instructions (e.g., STUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)). +// by BPF instructions (e.g., IncomingSTUNFilter that checks and sends only STUN packets to the listeners (ReadFrom)). +// It is meant to be used when sharing a port with some other process. type SharedSocket struct { ctx context.Context conn4 *socket.Conn diff --git a/sharedsock/sock_linux_test.go b/sharedsock/sock_linux_test.go index 8e774cafc..40e880572 100644 --- a/sharedsock/sock_linux_test.go +++ b/sharedsock/sock_linux_test.go @@ -21,7 +21,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) { // create raw socket on a port testingPort := 51821 - rawSock, err := Listen(testingPort, NewSTUNFilter()) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second)) require.NoError(t, err, "unable to set deadline, error: %s", err) @@ -76,7 +76,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) { func TestShouldNotReadNonSTUNPackets(t *testing.T) { testingPort := 39439 - rawSock, err := Listen(testingPort, NewSTUNFilter()) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) defer rawSock.Close() @@ -110,7 +110,7 @@ func TestWriteTo(t *testing.T) { defer udpListener.Close() testingPort := 39440 - rawSock, err := Listen(testingPort, NewSTUNFilter()) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) defer rawSock.Close() @@ -144,7 +144,7 @@ func TestWriteTo(t *testing.T) { } func TestSharedSocket_Close(t *testing.T) { - rawSock, err := Listen(39440, NewSTUNFilter()) + rawSock, err := Listen(39440, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) errGrp := errgroup.Group{}