Add packet capture to debug bundle and CLI

This commit is contained in:
Viktor Liu
2026-04-15 07:26:13 +02:00
parent e804a705b7
commit e58c29d4f9
44 changed files with 4327 additions and 238 deletions

View File

@@ -0,0 +1,193 @@
package capture
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// htons converts a uint16 from host to network (big-endian) byte order.
func htons(v uint16) uint16 {
var buf [2]byte
binary.BigEndian.PutUint16(buf[:], v)
return binary.NativeEndian.Uint16(buf[:])
}
// AFPacketCapture reads raw packets from a network interface using an
// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is
// not available (kernel WireGuard). Linux only.
//
// It implements device.PacketCapture so it can be set on a Session, but it
// drives its own read loop rather than being called from FilteredDevice.
// Call Start to begin and Stop to end.
type AFPacketCapture struct {
ifaceName string
sess *Session
fd int
mu sync.Mutex
stopped chan struct{}
started atomic.Bool
closed atomic.Bool
}
// NewAFPacketCapture creates a capture bound to the given interface.
// The session receives packets via Offer.
func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture {
return &AFPacketCapture{
ifaceName: ifaceName,
sess: sess,
fd: -1,
stopped: make(chan struct{}),
}
}
// Start opens the AF_PACKET socket and begins reading packets.
// Packets are fed to the session via Offer. Returns immediately;
// the read loop runs in a goroutine.
func (c *AFPacketCapture) Start() error {
if c.sess == nil {
return errors.New("nil capture session")
}
if c.started.Load() {
return errors.New("capture already started")
}
iface, err := net.InterfaceByName(c.ifaceName)
if err != nil {
return fmt.Errorf("interface %s: %w", c.ifaceName, err)
}
fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL)))
if err != nil {
return fmt.Errorf("create AF_PACKET socket: %w", err)
}
addr := &unix.SockaddrLinklayer{
Protocol: htons(unix.ETH_P_ALL),
Ifindex: iface.Index,
}
if err := unix.Bind(fd, addr); err != nil {
unix.Close(fd)
return fmt.Errorf("bind to %s: %w", c.ifaceName, err)
}
c.mu.Lock()
c.fd = fd
c.mu.Unlock()
c.started.Store(true)
go c.readLoop(fd)
return nil
}
// Stop closes the socket and waits for the read loop to exit. Idempotent.
func (c *AFPacketCapture) Stop() {
if !c.closed.CompareAndSwap(false, true) {
if c.started.Load() {
<-c.stopped
}
return
}
c.mu.Lock()
fd := c.fd
c.fd = -1
c.mu.Unlock()
if fd >= 0 {
unix.Close(fd)
}
if c.started.Load() {
<-c.stopped
}
}
func (c *AFPacketCapture) readLoop(fd int) {
defer close(c.stopped)
buf := make([]byte, 65536)
pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
for {
if c.closed.Load() {
return
}
ok, err := c.pollOnce(pollFds)
if err != nil {
return
}
if !ok {
continue
}
c.recvAndOffer(fd, buf)
}
}
// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry.
// Returns an error to signal the loop should exit.
func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) {
n, err := unix.Poll(pollFds, 200)
if err != nil {
if errors.Is(err, unix.EINTR) {
return false, nil
}
if c.closed.Load() {
return false, errors.New("closed")
}
log.Debugf("af_packet poll: %v", err)
return false, err
}
if n == 0 {
return false, nil
}
if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 {
return false, errors.New("fd error")
}
return true, nil
}
func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) {
nr, from, err := unix.Recvfrom(fd, buf, 0)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
return
}
if !c.closed.Load() {
log.Debugf("af_packet recvfrom: %v", err)
}
return
}
if nr < 1 {
return
}
ver := buf[0] >> 4
if ver != 4 && ver != 6 {
return
}
// The kernel sets Pkttype on AF_PACKET sockets:
// PACKET_HOST(0) = addressed to us (inbound)
// PACKET_OUTGOING(4) = sent by us (outbound)
outbound := false
if sa, ok := from.(*unix.SockaddrLinklayer); ok {
outbound = sa.Pkttype == unix.PACKET_OUTGOING
}
c.sess.Offer(buf[:nr], outbound)
}
// Offer satisfies device.PacketCapture but is unused: the AFPacketCapture
// drives its own read loop. This exists only so the type signature is
// compatible if someone tries to set it as a PacketCapture.
func (c *AFPacketCapture) Offer([]byte, bool) {
// unused: AFPacketCapture drives its own read loop
}

View File

@@ -0,0 +1,26 @@
//go:build !linux
package capture
import "errors"
// AFPacketCapture is not available on this platform.
type AFPacketCapture struct{}
// NewAFPacketCapture returns nil on non-Linux platforms.
func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil }
// Start returns an error on non-Linux platforms.
func (c *AFPacketCapture) Start() error {
return errors.New("AF_PACKET capture is only supported on Linux")
}
// Stop is a no-op on non-Linux platforms.
func (c *AFPacketCapture) Stop() {
// no-op on non-Linux platforms
}
// Offer is a no-op on non-Linux platforms.
func (c *AFPacketCapture) Offer([]byte, bool) {
// no-op on non-Linux platforms
}

59
util/capture/capture.go Normal file
View File

@@ -0,0 +1,59 @@
// Package capture provides userspace packet capture in pcap format.
//
// It taps decrypted WireGuard packets flowing through the FilteredDevice and
// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as
// human-readable one-line-per-packet text.
package capture
import "io"
// Direction indicates whether a packet is entering or leaving the host.
type Direction uint8
const (
// Inbound is a packet arriving from the network (FilteredDevice.Write path).
Inbound Direction = iota
// Outbound is a packet leaving the host (FilteredDevice.Read path).
Outbound
)
// String returns "IN" or "OUT".
func (d Direction) String() string {
if d == Outbound {
return "OUT"
}
return "IN"
}
const (
protoICMP = 1
protoTCP = 6
protoUDP = 17
protoICMPv6 = 58
)
// Options configures a capture session.
type Options struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Matcher selects which packets to capture. Nil captures all.
// Use ParseFilter("host 10.0.0.1 and tcp") or &Filter{...}.
Matcher Matcher
// Verbose adds seq/ack, TTL, window, total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
// SnapLen is the maximum bytes captured per packet. 0 means 65535.
SnapLen uint32
// BufSize is the internal channel buffer size. 0 means 256.
BufSize int
}
// Stats reports capture session counters.
type Stats struct {
Packets int64
Bytes int64
Dropped int64
}

528
util/capture/filter.go Normal file
View File

@@ -0,0 +1,528 @@
package capture
import (
"encoding/binary"
"fmt"
"net/netip"
"strconv"
"strings"
)
// Matcher tests whether a raw packet should be captured.
type Matcher interface {
Match(data []byte) bool
}
// Filter selects packets by flat AND'd criteria. Useful for structured APIs
// (query params, proto fields). Implements Matcher.
type Filter struct {
SrcIP netip.Addr
DstIP netip.Addr
Host netip.Addr
SrcPort uint16
DstPort uint16
Port uint16
Proto uint8
}
// IsEmpty returns true if the filter has no criteria set.
func (f *Filter) IsEmpty() bool {
return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() &&
f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0
}
// Match implements Matcher. All non-zero fields must match (AND).
func (f *Filter) Match(data []byte) bool {
if f.IsEmpty() {
return true
}
info, ok := parsePacketInfo(data)
if !ok {
return false
}
if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host {
return false
}
if f.SrcIP.IsValid() && info.srcIP != f.SrcIP {
return false
}
if f.DstIP.IsValid() && info.dstIP != f.DstIP {
return false
}
if f.Proto != 0 && info.proto != f.Proto {
return false
}
if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port {
return false
}
if f.SrcPort != 0 && info.srcPort != f.SrcPort {
return false
}
if f.DstPort != 0 && info.dstPort != f.DstPort {
return false
}
return true
}
// exprNode evaluates a filter condition against pre-parsed packet info.
type exprNode func(info *packetInfo) bool
// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree.
type exprMatcher struct {
root exprNode
}
func (m *exprMatcher) Match(data []byte) bool {
info, ok := parsePacketInfo(data)
if !ok {
return false
}
return m.root(&info)
}
func nodeAnd(a, b exprNode) exprNode {
return func(info *packetInfo) bool { return a(info) && b(info) }
}
func nodeOr(a, b exprNode) exprNode {
return func(info *packetInfo) bool { return a(info) || b(info) }
}
func nodeNot(n exprNode) exprNode {
return func(info *packetInfo) bool { return !n(info) }
}
func nodeHost(addr netip.Addr) exprNode {
return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr }
}
func nodeSrcHost(addr netip.Addr) exprNode {
return func(info *packetInfo) bool { return info.srcIP == addr }
}
func nodeDstHost(addr netip.Addr) exprNode {
return func(info *packetInfo) bool { return info.dstIP == addr }
}
func nodePort(port uint16) exprNode {
return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port }
}
func nodeSrcPort(port uint16) exprNode {
return func(info *packetInfo) bool { return info.srcPort == port }
}
func nodeDstPort(port uint16) exprNode {
return func(info *packetInfo) bool { return info.dstPort == port }
}
func nodeProto(proto uint8) exprNode {
return func(info *packetInfo) bool { return info.proto == proto }
}
func nodeFamily(family uint8) exprNode {
return func(info *packetInfo) bool { return info.family == family }
}
func nodeNet(prefix netip.Prefix) exprNode {
return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) }
}
func nodeSrcNet(prefix netip.Prefix) exprNode {
return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) }
}
func nodeDstNet(prefix netip.Prefix) exprNode {
return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) }
}
// packetInfo holds parsed header fields for filtering and display.
type packetInfo struct {
family uint8
srcIP netip.Addr
dstIP netip.Addr
proto uint8
srcPort uint16
dstPort uint16
hdrLen int
}
func parsePacketInfo(data []byte) (packetInfo, bool) {
if len(data) < 1 {
return packetInfo{}, false
}
switch data[0] >> 4 {
case 4:
return parseIPv4Info(data)
case 6:
return parseIPv6Info(data)
default:
return packetInfo{}, false
}
}
func parseIPv4Info(data []byte) (packetInfo, bool) {
if len(data) < 20 {
return packetInfo{}, false
}
ihl := int(data[0]&0x0f) * 4
if ihl < 20 || len(data) < ihl {
return packetInfo{}, false
}
info := packetInfo{
family: 4,
srcIP: netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}),
dstIP: netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}),
proto: data[9],
hdrLen: ihl,
}
if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= ihl+4 {
info.srcPort = binary.BigEndian.Uint16(data[ihl:])
info.dstPort = binary.BigEndian.Uint16(data[ihl+2:])
}
return info, true
}
// parseIPv6Info parses the fixed IPv6 header. It reads the Next Header field
// directly, so packets with extension headers (hop-by-hop, routing, fragment,
// etc.) will report the extension type as the protocol rather than the final
// transport protocol. This is acceptable for a debug capture tool.
func parseIPv6Info(data []byte) (packetInfo, bool) {
if len(data) < 40 {
return packetInfo{}, false
}
var src, dst [16]byte
copy(src[:], data[8:24])
copy(dst[:], data[24:40])
info := packetInfo{
family: 6,
srcIP: netip.AddrFrom16(src),
dstIP: netip.AddrFrom16(dst),
proto: data[6],
hdrLen: 40,
}
if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 {
info.srcPort = binary.BigEndian.Uint16(data[40:])
info.dstPort = binary.BigEndian.Uint16(data[42:])
}
return info, true
}
// ParseFilter parses a BPF-like filter expression and returns a Matcher.
// Returns nil Matcher for an empty expression (match all).
//
// Grammar (mirrors common tcpdump BPF syntax):
//
// orExpr = andExpr ("or" andExpr)*
// andExpr = unary ("and" unary)*
// unary = "not" unary | "(" orExpr ")" | term
//
// term = "host" IP | "src" target | "dst" target
// | "port" NUM | "net" PREFIX
// | "tcp" | "udp" | "icmp" | "icmp6"
// | "ip" | "ip6" | "proto" NUM
// target = "host" IP | "port" NUM | "net" PREFIX | IP
//
// Examples:
//
// host 10.0.0.1 and tcp port 443
// not port 22
// (host 10.0.0.1 or host 10.0.0.2) and tcp
// ip6 and icmp6
// net 10.0.0.0/24
// src host 10.0.0.1 or dst port 80
func ParseFilter(expr string) (Matcher, error) {
tokens := tokenize(expr)
if len(tokens) == 0 {
return nil, nil //nolint:nilnil // nil Matcher means "match all"
}
p := &parser{tokens: tokens}
node, err := p.parseOr()
if err != nil {
return nil, err
}
if p.pos < len(p.tokens) {
return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos)
}
return &exprMatcher{root: node}, nil
}
func tokenize(expr string) []string {
expr = strings.TrimSpace(expr)
if expr == "" {
return nil
}
// Split on whitespace but keep parens as separate tokens.
var tokens []string
for _, field := range strings.Fields(expr) {
tokens = append(tokens, splitParens(field)...)
}
return tokens
}
// splitParens splits "(foo)" into "(", "foo", ")".
func splitParens(s string) []string {
var out []string
for strings.HasPrefix(s, "(") {
out = append(out, "(")
s = s[1:]
}
var trail []string
for strings.HasSuffix(s, ")") {
trail = append(trail, ")")
s = s[:len(s)-1]
}
if s != "" {
out = append(out, s)
}
out = append(out, trail...)
return out
}
type parser struct {
tokens []string
pos int
}
func (p *parser) peek() string {
if p.pos >= len(p.tokens) {
return ""
}
return strings.ToLower(p.tokens[p.pos])
}
func (p *parser) next() string {
tok := p.peek()
if tok != "" {
p.pos++
}
return tok
}
func (p *parser) expect(tok string) error {
got := p.next()
if got != tok {
return fmt.Errorf("expected %q, got %q", tok, got)
}
return nil
}
func (p *parser) parseOr() (exprNode, error) {
left, err := p.parseAnd()
if err != nil {
return nil, err
}
for p.peek() == "or" {
p.next()
right, err := p.parseAnd()
if err != nil {
return nil, err
}
left = nodeOr(left, right)
}
return left, nil
}
func (p *parser) parseAnd() (exprNode, error) {
left, err := p.parseUnary()
if err != nil {
return nil, err
}
for {
tok := p.peek()
if tok == "and" {
p.next()
right, err := p.parseUnary()
if err != nil {
return nil, err
}
left = nodeAnd(left, right)
continue
}
// Implicit AND: two atoms without "and" between them.
// Only if the next token starts an atom (not "or", ")", or EOF).
if tok != "" && tok != "or" && tok != ")" {
right, err := p.parseUnary()
if err != nil {
return nil, err
}
left = nodeAnd(left, right)
continue
}
break
}
return left, nil
}
func (p *parser) parseUnary() (exprNode, error) {
switch p.peek() {
case "not":
p.next()
inner, err := p.parseUnary()
if err != nil {
return nil, err
}
return nodeNot(inner), nil
case "(":
p.next()
inner, err := p.parseOr()
if err != nil {
return nil, err
}
if err := p.expect(")"); err != nil {
return nil, fmt.Errorf("unclosed parenthesis")
}
return inner, nil
default:
return p.parseAtom()
}
}
func (p *parser) parseAtom() (exprNode, error) {
tok := p.next()
if tok == "" {
return nil, fmt.Errorf("unexpected end of expression")
}
switch tok {
case "host":
addr, err := p.parseAddr()
if err != nil {
return nil, fmt.Errorf("host: %w", err)
}
return nodeHost(addr), nil
case "port":
port, err := p.parsePort()
if err != nil {
return nil, fmt.Errorf("port: %w", err)
}
return nodePort(port), nil
case "net":
prefix, err := p.parsePrefix()
if err != nil {
return nil, fmt.Errorf("net: %w", err)
}
return nodeNet(prefix), nil
case "src":
return p.parseDirTarget(true)
case "dst":
return p.parseDirTarget(false)
case "tcp":
return nodeProto(protoTCP), nil
case "udp":
return nodeProto(protoUDP), nil
case "icmp":
return nodeProto(protoICMP), nil
case "icmp6":
return nodeProto(protoICMPv6), nil
case "ip":
return nodeFamily(4), nil
case "ip6":
return nodeFamily(6), nil
case "proto":
raw := p.next()
if raw == "" {
return nil, fmt.Errorf("proto: missing number")
}
n, err := strconv.Atoi(raw)
if err != nil || n < 0 || n > 255 {
return nil, fmt.Errorf("proto: invalid number %q", raw)
}
return nodeProto(uint8(n)), nil
default:
return nil, fmt.Errorf("unknown filter keyword %q", tok)
}
}
func (p *parser) parseDirTarget(isSrc bool) (exprNode, error) {
tok := p.peek()
switch tok {
case "host":
p.next()
addr, err := p.parseAddr()
if err != nil {
return nil, err
}
if isSrc {
return nodeSrcHost(addr), nil
}
return nodeDstHost(addr), nil
case "port":
p.next()
port, err := p.parsePort()
if err != nil {
return nil, err
}
if isSrc {
return nodeSrcPort(port), nil
}
return nodeDstPort(port), nil
case "net":
p.next()
prefix, err := p.parsePrefix()
if err != nil {
return nil, err
}
if isSrc {
return nodeSrcNet(prefix), nil
}
return nodeDstNet(prefix), nil
default:
// Try as bare IP: "src 10.0.0.1"
addr, err := p.parseAddr()
if err != nil {
return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok)
}
if isSrc {
return nodeSrcHost(addr), nil
}
return nodeDstHost(addr), nil
}
}
func (p *parser) parseAddr() (netip.Addr, error) {
raw := p.next()
if raw == "" {
return netip.Addr{}, fmt.Errorf("missing IP address")
}
addr, err := netip.ParseAddr(raw)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP %q", raw)
}
return addr.Unmap(), nil
}
func (p *parser) parsePort() (uint16, error) {
raw := p.next()
if raw == "" {
return 0, fmt.Errorf("missing port number")
}
n, err := strconv.Atoi(raw)
if err != nil || n < 1 || n > 65535 {
return 0, fmt.Errorf("invalid port %q", raw)
}
return uint16(n), nil
}
func (p *parser) parsePrefix() (netip.Prefix, error) {
raw := p.next()
if raw == "" {
return netip.Prefix{}, fmt.Errorf("missing network prefix")
}
prefix, err := netip.ParsePrefix(raw)
if err != nil {
return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw)
}
return prefix, nil
}

263
util/capture/filter_test.go Normal file
View File

@@ -0,0 +1,263 @@
package capture
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing.
func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte {
t.Helper()
hdrLen := 20
pkt := make([]byte, hdrLen+20)
pkt[0] = 0x45
pkt[9] = proto
src := srcIP.As4()
dst := dstIP.As4()
copy(pkt[12:16], src[:])
copy(pkt[16:20], dst[:])
pkt[20] = byte(srcPort >> 8)
pkt[21] = byte(srcPort)
pkt[22] = byte(dstPort >> 8)
pkt[23] = byte(dstPort)
return pkt
}
// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing.
func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte {
t.Helper()
pkt := make([]byte, 44) // 40 header + 4 ports
pkt[0] = 0x60 // version 6
pkt[6] = proto // next header
src := srcIP.As16()
dst := dstIP.As16()
copy(pkt[8:24], src[:])
copy(pkt[24:40], dst[:])
pkt[40] = byte(srcPort >> 8)
pkt[41] = byte(srcPort)
pkt[42] = byte(dstPort >> 8)
pkt[43] = byte(dstPort)
return pkt
}
// ---- Filter struct tests ----
func TestFilter_Empty(t *testing.T) {
f := Filter{}
assert.True(t, f.IsEmpty())
assert.True(t, f.Match(buildIPv4Packet(t,
netip.MustParseAddr("10.0.0.1"),
netip.MustParseAddr("10.0.0.2"),
protoTCP, 12345, 443)))
}
func TestFilter_Host(t *testing.T) {
f := Filter{Host: netip.MustParseAddr("10.0.0.1")}
assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80)))
assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80)))
assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80)))
}
func TestFilter_InvalidPacket(t *testing.T) {
f := Filter{Host: netip.MustParseAddr("10.0.0.1")}
assert.False(t, f.Match(nil))
assert.False(t, f.Match([]byte{}))
assert.False(t, f.Match([]byte{0x00}))
}
func TestParsePacketInfo_IPv4(t *testing.T) {
pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80)
info, ok := parsePacketInfo(pkt)
require.True(t, ok)
assert.Equal(t, uint8(4), info.family)
assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP)
assert.Equal(t, uint8(protoTCP), info.proto)
assert.Equal(t, uint16(54321), info.srcPort)
assert.Equal(t, uint16(80), info.dstPort)
}
func TestParsePacketInfo_IPv6(t *testing.T) {
pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53)
info, ok := parsePacketInfo(pkt)
require.True(t, ok)
assert.Equal(t, uint8(6), info.family)
assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP)
assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP)
assert.Equal(t, uint8(protoUDP), info.proto)
assert.Equal(t, uint16(1234), info.srcPort)
assert.Equal(t, uint16(53), info.dstPort)
}
// ---- ParseFilter expression tests ----
func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool {
t.Helper()
return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort))
}
func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool {
t.Helper()
return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort))
}
func TestParseFilter_Empty(t *testing.T) {
m, err := ParseFilter("")
require.NoError(t, err)
assert.Nil(t, m, "empty expression should return nil matcher")
}
func TestParseFilter_Atoms(t *testing.T) {
tests := []struct {
expr string
match bool
}{
{"tcp", true},
{"udp", false},
{"host 10.0.0.1", true},
{"host 10.0.0.99", false},
{"port 443", true},
{"port 80", false},
{"src host 10.0.0.1", true},
{"dst host 10.0.0.2", true},
{"dst host 10.0.0.1", false},
{"src port 12345", true},
{"dst port 443", true},
{"dst port 80", false},
{"proto 6", true},
{"proto 17", false},
}
pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443)
for _, tt := range tests {
t.Run(tt.expr, func(t *testing.T) {
m, err := ParseFilter(tt.expr)
require.NoError(t, err)
assert.Equal(t, tt.match, m.Match(pkt))
})
}
}
func TestParseFilter_And(t *testing.T) {
m, err := ParseFilter("host 10.0.0.1 and tcp port 443")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443))
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto")
assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host")
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port")
}
func TestParseFilter_ImplicitAnd(t *testing.T) {
// "tcp port 443" = implicit AND between tcp and port 443
m, err := ParseFilter("tcp port 443")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443))
}
func TestParseFilter_Or(t *testing.T) {
m, err := ParseFilter("port 80 or port 443")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80))
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443))
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080))
}
func TestParseFilter_Not(t *testing.T) {
m, err := ParseFilter("not port 22")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22))
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80))
}
func TestParseFilter_Parens(t *testing.T) {
m, err := ParseFilter("(port 80 or port 443) and tcp")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443))
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto")
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port")
}
func TestParseFilter_Family(t *testing.T) {
mV4, err := ParseFilter("ip")
require.NoError(t, err)
assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80))
assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80))
mV6, err := ParseFilter("ip6")
require.NoError(t, err)
assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80))
assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80))
}
func TestParseFilter_Net(t *testing.T) {
m, err := ParseFilter("net 10.0.0.0/24")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net")
assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net")
assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net")
}
func TestParseFilter_SrcDstNet(t *testing.T) {
m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80))
assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed")
}
func TestParseFilter_Complex(t *testing.T) {
// Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH
m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443))
assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80))
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded")
assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host")
}
func TestParseFilter_IPv6Combined(t *testing.T) {
m, err := ParseFilter("ip6 and icmp6")
require.NoError(t, err)
assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0))
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family")
assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto")
}
func TestParseFilter_CaseInsensitive(t *testing.T) {
m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443")
require.NoError(t, err)
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
}
func TestParseFilter_Errors(t *testing.T) {
bad := []string{
"badkeyword",
"host",
"port abc",
"port 99999",
"net invalid",
"(",
"(port 80",
"not",
"src",
}
for _, expr := range bad {
t.Run(expr, func(t *testing.T) {
_, err := ParseFilter(expr)
assert.Error(t, err, "should fail for %q", expr)
})
}
}

85
util/capture/pcap.go Normal file
View File

@@ -0,0 +1,85 @@
package capture
import (
"encoding/binary"
"io"
"time"
)
const (
pcapMagic = 0xa1b2c3d4
pcapVersionMaj = 2
pcapVersionMin = 4
// linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header.
linkTypeRaw = 101
defaultSnapLen = 65535
)
// PcapWriter writes packets in pcap format to an underlying writer.
// The global header is written lazily on the first WritePacket call so that
// the writer can be used with unbuffered io.Pipes without deadlocking.
// It is not safe for concurrent use; callers must serialize access.
type PcapWriter struct {
w io.Writer
snapLen uint32
headerWritten bool
}
// NewPcapWriter creates a pcap writer. The global header is deferred until the
// first WritePacket call.
func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter {
if snapLen == 0 {
snapLen = defaultSnapLen
}
return &PcapWriter{w: w, snapLen: snapLen}
}
// writeGlobalHeader writes the 24-byte pcap file header.
func (pw *PcapWriter) writeGlobalHeader() error {
var hdr [24]byte
binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic)
binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj)
binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin)
binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen)
binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw)
_, err := pw.w.Write(hdr[:])
return err
}
// WriteHeader writes the pcap global header. Safe to call multiple times.
func (pw *PcapWriter) WriteHeader() error {
if pw.headerWritten {
return nil
}
if err := pw.writeGlobalHeader(); err != nil {
return err
}
pw.headerWritten = true
return nil
}
// WritePacket writes a single packet record, preceded by the global header
// on the first call.
func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error {
if err := pw.WriteHeader(); err != nil {
return err
}
origLen := uint32(len(data))
if origLen > pw.snapLen {
data = data[:pw.snapLen]
}
var hdr [16]byte
binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix()))
binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000))
binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data)))
binary.LittleEndian.PutUint32(hdr[12:16], origLen)
if _, err := pw.w.Write(hdr[:]); err != nil {
return err
}
_, err := pw.w.Write(data)
return err
}

68
util/capture/pcap_test.go Normal file
View File

@@ -0,0 +1,68 @@
package capture
import (
"bytes"
"encoding/binary"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPcapWriter_GlobalHeader(t *testing.T) {
var buf bytes.Buffer
pw := NewPcapWriter(&buf, 0)
// Header is lazy, so write a dummy packet to trigger it.
err := pw.WritePacket(time.Now(), []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 64, 1, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2})
require.NoError(t, err)
data := buf.Bytes()
require.GreaterOrEqual(t, len(data), 24, "should contain global header")
assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]), "magic number")
assert.Equal(t, uint16(pcapVersionMaj), binary.LittleEndian.Uint16(data[4:6]), "version major")
assert.Equal(t, uint16(pcapVersionMin), binary.LittleEndian.Uint16(data[6:8]), "version minor")
assert.Equal(t, uint32(defaultSnapLen), binary.LittleEndian.Uint32(data[16:20]), "snap length")
assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]), "link type")
}
func TestPcapWriter_WritePacket(t *testing.T) {
var buf bytes.Buffer
pw := NewPcapWriter(&buf, 100)
ts := time.Date(2025, 6, 15, 12, 30, 45, 123456000, time.UTC)
payload := make([]byte, 50)
for i := range payload {
payload[i] = byte(i)
}
err := pw.WritePacket(ts, payload)
require.NoError(t, err)
data := buf.Bytes()[24:] // skip global header
require.Len(t, data, 16+50, "packet header + payload")
assert.Equal(t, uint32(ts.Unix()), binary.LittleEndian.Uint32(data[0:4]), "timestamp seconds")
assert.Equal(t, uint32(123456), binary.LittleEndian.Uint32(data[4:8]), "timestamp microseconds")
assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[8:12]), "included length")
assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length")
assert.Equal(t, payload, data[16:], "packet data")
}
func TestPcapWriter_SnapLen(t *testing.T) {
var buf bytes.Buffer
pw := NewPcapWriter(&buf, 10)
ts := time.Now()
payload := make([]byte, 50)
err := pw.WritePacket(ts, payload)
require.NoError(t, err)
data := buf.Bytes()[24:]
assert.Equal(t, uint32(10), binary.LittleEndian.Uint32(data[8:12]), "included length should be truncated")
assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length preserved")
assert.Len(t, data[16:], 10, "only snap_len bytes written")
}

213
util/capture/session.go Normal file
View File

@@ -0,0 +1,213 @@
package capture
import (
"fmt"
"sync"
"sync/atomic"
"time"
)
const defaultBufSize = 256
type packetEntry struct {
ts time.Time
data []byte
dir Direction
}
// Session manages an active packet capture. Packets are offered via Offer,
// buffered in a channel, and written to configured sinks by a background
// goroutine. This keeps the hot path (FilteredDevice.Read/Write) non-blocking.
//
// The caller must call Stop when done to flush remaining packets and release
// resources.
type Session struct {
pcapW *PcapWriter
textW *TextWriter
matcher Matcher
snapLen uint32
flushFn func()
ch chan packetEntry
done chan struct{}
stopped chan struct{}
closeOnce sync.Once
closed atomic.Bool
packets atomic.Int64
bytes atomic.Int64
dropped atomic.Int64
started time.Time
}
// NewSession creates and starts a capture session. At least one of
// Options.Output or Options.TextOutput must be non-nil.
func NewSession(opts Options) (*Session, error) {
if opts.Output == nil && opts.TextOutput == nil {
return nil, fmt.Errorf("at least one output sink required")
}
snapLen := opts.SnapLen
if snapLen == 0 {
snapLen = defaultSnapLen
}
bufSize := opts.BufSize
if bufSize <= 0 {
bufSize = defaultBufSize
}
s := &Session{
matcher: opts.Matcher,
snapLen: snapLen,
ch: make(chan packetEntry, bufSize),
done: make(chan struct{}),
stopped: make(chan struct{}),
started: time.Now(),
}
if opts.Output != nil {
s.pcapW = NewPcapWriter(opts.Output, snapLen)
}
if opts.TextOutput != nil {
s.textW = NewTextWriter(opts.TextOutput, opts.Verbose, opts.ASCII)
}
s.flushFn = buildFlushFn(opts.Output, opts.TextOutput)
go s.run()
return s, nil
}
// Offer submits a packet for capture. It returns immediately and never blocks
// the caller. If the internal buffer is full the packet is dropped silently.
//
// outbound should be true for packets leaving the host (FilteredDevice.Read
// path) and false for packets arriving (FilteredDevice.Write path).
//
// Offer satisfies the device.PacketCapture interface.
func (s *Session) Offer(data []byte, outbound bool) {
if s.closed.Load() {
return
}
if s.matcher != nil && !s.matcher.Match(data) {
return
}
captureLen := len(data)
if s.snapLen > 0 && uint32(captureLen) > s.snapLen {
captureLen = int(s.snapLen)
}
copied := make([]byte, captureLen)
copy(copied, data)
dir := Inbound
if outbound {
dir = Outbound
}
select {
case s.ch <- packetEntry{ts: time.Now(), data: copied, dir: dir}:
s.packets.Add(1)
s.bytes.Add(int64(len(data)))
default:
s.dropped.Add(1)
}
}
// Stop signals the session to stop accepting packets, drains any buffered
// packets to the sinks, and waits for the writer goroutine to exit.
// It is safe to call multiple times.
func (s *Session) Stop() {
s.closeOnce.Do(func() {
s.closed.Store(true)
close(s.done)
})
<-s.stopped
}
// Done returns a channel that is closed when the session's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (s *Session) Done() <-chan struct{} {
return s.stopped
}
// Stats returns current capture counters.
func (s *Session) Stats() Stats {
return Stats{
Packets: s.packets.Load(),
Bytes: s.bytes.Load(),
Dropped: s.dropped.Load(),
}
}
func (s *Session) run() {
defer close(s.stopped)
for {
select {
case pkt := <-s.ch:
s.write(pkt)
case <-s.done:
s.drain()
return
}
}
}
func (s *Session) drain() {
for {
select {
case pkt := <-s.ch:
s.write(pkt)
default:
return
}
}
}
func (s *Session) write(pkt packetEntry) {
if s.pcapW != nil {
// Best-effort: if the writer fails (broken pipe etc.), discard silently.
_ = s.pcapW.WritePacket(pkt.ts, pkt.data)
}
if s.textW != nil {
_ = s.textW.WritePacket(pkt.ts, pkt.data, pkt.dir)
}
s.flushFn()
}
// buildFlushFn returns a function that flushes all writers that support it.
// This covers http.Flusher and similar streaming writers.
func buildFlushFn(writers ...any) func() {
type flusher interface {
Flush()
}
var fns []func()
for _, w := range writers {
if w == nil {
continue
}
if f, ok := w.(flusher); ok {
fns = append(fns, f.Flush)
}
}
switch len(fns) {
case 0:
return func() {
// no writers to flush
}
case 1:
return fns[0]
default:
return func() {
for _, fn := range fns {
fn()
}
}
}
}

View File

@@ -0,0 +1,144 @@
package capture
import (
"bytes"
"encoding/binary"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSession_PcapOutput(t *testing.T) {
var buf bytes.Buffer
sess, err := NewSession(Options{
Output: &buf,
BufSize: 16,
})
require.NoError(t, err)
pkt := buildIPv4Packet(t,
netip.MustParseAddr("10.0.0.1"),
netip.MustParseAddr("10.0.0.2"),
protoTCP, 12345, 443)
sess.Offer(pkt, true)
sess.Stop()
data := buf.Bytes()
require.Greater(t, len(data), 24, "should have global header + at least one packet")
// Verify global header
assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]))
assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]))
// Verify packet record
pktData := data[24:]
inclLen := binary.LittleEndian.Uint32(pktData[8:12])
assert.Equal(t, uint32(len(pkt)), inclLen)
stats := sess.Stats()
assert.Equal(t, int64(1), stats.Packets)
assert.Equal(t, int64(len(pkt)), stats.Bytes)
assert.Equal(t, int64(0), stats.Dropped)
}
func TestSession_TextOutput(t *testing.T) {
var buf bytes.Buffer
sess, err := NewSession(Options{
TextOutput: &buf,
BufSize: 16,
})
require.NoError(t, err)
pkt := buildIPv4Packet(t,
netip.MustParseAddr("10.0.0.1"),
netip.MustParseAddr("10.0.0.2"),
protoTCP, 12345, 443)
sess.Offer(pkt, false)
sess.Stop()
output := buf.String()
assert.Contains(t, output, "TCP")
assert.Contains(t, output, "10.0.0.1")
assert.Contains(t, output, "10.0.0.2")
assert.Contains(t, output, "443")
assert.Contains(t, output, "[IN TCP]")
}
func TestSession_Filter(t *testing.T) {
var buf bytes.Buffer
sess, err := NewSession(Options{
Output: &buf,
Matcher: &Filter{Port: 443},
})
require.NoError(t, err)
pktMatch := buildIPv4Packet(t,
netip.MustParseAddr("10.0.0.1"),
netip.MustParseAddr("10.0.0.2"),
protoTCP, 12345, 443)
pktNoMatch := buildIPv4Packet(t,
netip.MustParseAddr("10.0.0.1"),
netip.MustParseAddr("10.0.0.2"),
protoTCP, 12345, 80)
sess.Offer(pktMatch, true)
sess.Offer(pktNoMatch, true)
sess.Stop()
stats := sess.Stats()
assert.Equal(t, int64(1), stats.Packets, "only matching packet should be captured")
}
func TestSession_StopIdempotent(t *testing.T) {
var buf bytes.Buffer
sess, err := NewSession(Options{Output: &buf})
require.NoError(t, err)
sess.Stop()
sess.Stop() // should not panic or deadlock
}
func TestSession_OfferAfterStop(t *testing.T) {
var buf bytes.Buffer
sess, err := NewSession(Options{Output: &buf})
require.NoError(t, err)
sess.Stop()
pkt := buildIPv4Packet(t,
netip.MustParseAddr("10.0.0.1"),
netip.MustParseAddr("10.0.0.2"),
protoTCP, 12345, 443)
sess.Offer(pkt, true) // should not panic
assert.Equal(t, int64(0), sess.Stats().Packets)
}
func TestSession_Done(t *testing.T) {
var buf bytes.Buffer
sess, err := NewSession(Options{Output: &buf})
require.NoError(t, err)
select {
case <-sess.Done():
t.Fatal("Done should not be closed before Stop")
default:
}
sess.Stop()
select {
case <-sess.Done():
case <-time.After(time.Second):
t.Fatal("Done should be closed after Stop")
}
}
func TestSession_RequiresOutput(t *testing.T) {
_, err := NewSession(Options{})
assert.Error(t, err)
}

638
util/capture/text.go Normal file
View File

@@ -0,0 +1,638 @@
package capture
import (
"encoding/binary"
"fmt"
"io"
"net/netip"
"strings"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// TextWriter writes human-readable one-line-per-packet summaries.
// It is not safe for concurrent use; callers must serialize access.
type TextWriter struct {
w io.Writer
verbose bool
ascii bool
flows map[dirKey]uint32
}
type dirKey struct {
src netip.AddrPort
dst netip.AddrPort
}
// NewTextWriter creates a text formatter that writes to w.
func NewTextWriter(w io.Writer, verbose, ascii bool) *TextWriter {
return &TextWriter{
w: w,
verbose: verbose,
ascii: ascii,
flows: make(map[dirKey]uint32),
}
}
// tag formats the fixed-width "[DIR PROTO]" prefix with right-aligned protocol.
func tag(dir Direction, proto string) string {
return fmt.Sprintf("[%-3s %4s]", dir, proto)
}
// WritePacket formats and writes a single packet line.
func (tw *TextWriter) WritePacket(ts time.Time, data []byte, dir Direction) error {
ts = ts.Local()
info, ok := parsePacketInfo(data)
if !ok {
_, err := fmt.Fprintf(tw.w, "%s [%-3s ?] ??? len=%d\n",
ts.Format("15:04:05.000000"), dir, len(data))
return err
}
timeStr := ts.Format("15:04:05.000000")
var err error
switch info.proto {
case protoTCP:
err = tw.writeTCP(timeStr, dir, &info, data)
case protoUDP:
err = tw.writeUDP(timeStr, dir, &info, data)
case protoICMP:
err = tw.writeICMPv4(timeStr, dir, &info, data)
case protoICMPv6:
err = tw.writeICMPv6(timeStr, dir, &info, data)
default:
var verbose string
if tw.verbose {
verbose = tw.verboseIP(data, info.family)
}
_, err = fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n",
timeStr, tag(dir, fmt.Sprintf("P%d", info.proto)),
info.srcIP, info.dstIP, len(data)-info.hdrLen, verbose)
}
return err
}
func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, data []byte) error {
tcp := &layers.TCP{}
if err := tcp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
return tw.writeFallback(timeStr, dir, "TCP", info, data)
}
flags := tcpFlagsStr(tcp)
plen := len(tcp.Payload)
// Protocol annotation
var annotation string
if plen > 0 {
annotation = annotatePayload(tcp.Payload)
}
if !tw.verbose {
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s] length %d%s\n",
timeStr, tag(dir, "TCP"),
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
flags, plen, annotation)
if err != nil {
return err
}
if tw.ascii && plen > 0 {
return tw.writeASCII(tcp.Payload)
}
return nil
}
relSeq, relAck := tw.relativeSeqAck(info, tcp.Seq, tcp.Ack)
var seqStr string
if plen > 0 {
seqStr = fmt.Sprintf(", seq %d:%d", relSeq, relSeq+uint32(plen))
} else {
seqStr = fmt.Sprintf(", seq %d", relSeq)
}
var ackStr string
if tcp.ACK {
ackStr = fmt.Sprintf(", ack %d", relAck)
}
var opts string
if s := formatTCPOptions(tcp.Options); s != "" {
opts = ", options [" + s + "]"
}
verbose := tw.verboseIP(data, info.family)
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s]%s%s, win %d%s, length %d%s%s\n",
timeStr, tag(dir, "TCP"),
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose)
if err != nil {
return err
}
if tw.ascii && plen > 0 {
return tw.writeASCII(tcp.Payload)
}
return nil
}
func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, data []byte) error {
udp := &layers.UDP{}
if err := udp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
return tw.writeFallback(timeStr, dir, "UDP", info, data)
}
plen := len(udp.Payload)
// DNS replaces the entire line format
if plen > 0 && isDNSPort(info.srcPort, info.dstPort) {
if s := formatDNSPayload(udp.Payload); s != "" {
var verbose string
if tw.verbose {
verbose = tw.verboseIP(data, info.family)
}
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d %s%s\n",
timeStr, tag(dir, "UDP"),
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
s, verbose)
return err
}
}
var verbose string
if tw.verbose {
verbose = tw.verboseIP(data, info.family)
}
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d%s\n",
timeStr, tag(dir, "UDP"),
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
plen, verbose)
if err != nil {
return err
}
if tw.ascii && plen > 0 {
return tw.writeASCII(udp.Payload)
}
return nil
}
func (tw *TextWriter) writeICMPv4(timeStr string, dir Direction, info *packetInfo, data []byte) error {
icmp := &layers.ICMPv4{}
if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
return tw.writeFallback(timeStr, dir, "ICMP", info, data)
}
var detail string
if icmp.TypeCode.Type() == layers.ICMPv4TypeEchoRequest || icmp.TypeCode.Type() == layers.ICMPv4TypeEchoReply {
detail = fmt.Sprintf("%s, id %d, seq %d", icmp.TypeCode.String(), icmp.Id, icmp.Seq)
} else {
detail = icmp.TypeCode.String()
}
var verbose string
if tw.verbose {
verbose = tw.verboseIP(data, info.family)
}
_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n",
timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, detail, len(data)-info.hdrLen, verbose)
return err
}
func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInfo, data []byte) error {
icmp := &layers.ICMPv6{}
if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
return tw.writeFallback(timeStr, dir, "ICMP", info, data)
}
var verbose string
if tw.verbose {
verbose = tw.verboseIP(data, info.family)
}
_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n",
timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, icmp.TypeCode.String(), len(data)-info.hdrLen, verbose)
return err
}
func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error {
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d\n",
timeStr, tag(dir, proto),
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
len(data)-info.hdrLen)
return err
}
func (tw *TextWriter) verboseIP(data []byte, family uint8) string {
return fmt.Sprintf(", ttl %d, id %d, iplen %d",
ipTTL(data, family), ipID(data, family), len(data))
}
// relativeSeqAck returns seq/ack relative to the first seen value per direction.
func (tw *TextWriter) relativeSeqAck(info *packetInfo, seq, ack uint32) (relSeq, relAck uint32) {
fwd := dirKey{
src: netip.AddrPortFrom(info.srcIP, info.srcPort),
dst: netip.AddrPortFrom(info.dstIP, info.dstPort),
}
rev := dirKey{
src: netip.AddrPortFrom(info.dstIP, info.dstPort),
dst: netip.AddrPortFrom(info.srcIP, info.srcPort),
}
if isn, ok := tw.flows[fwd]; ok {
relSeq = seq - isn
} else {
tw.flows[fwd] = seq
}
if isn, ok := tw.flows[rev]; ok {
relAck = ack - isn
} else {
relAck = ack
}
return relSeq, relAck
}
// writeASCII prints payload bytes as printable ASCII.
func (tw *TextWriter) writeASCII(payload []byte) error {
if len(payload) == 0 {
return nil
}
buf := make([]byte, len(payload))
for i, b := range payload {
switch {
case b >= 0x20 && b < 0x7f:
buf[i] = b
case b == '\n' || b == '\r' || b == '\t':
buf[i] = b
default:
buf[i] = '.'
}
}
_, err := fmt.Fprintf(tw.w, "%s\n", buf)
return err
}
// --- TCP helpers ---
func ipTTL(data []byte, family uint8) uint8 {
if family == 4 && len(data) > 8 {
return data[8]
}
if family == 6 && len(data) > 7 {
return data[7]
}
return 0
}
func ipID(data []byte, family uint8) uint16 {
if family == 4 && len(data) >= 6 {
return binary.BigEndian.Uint16(data[4:6])
}
return 0
}
func tcpFlagsStr(tcp *layers.TCP) string {
var buf [6]byte
n := 0
if tcp.SYN {
buf[n] = 'S'
n++
}
if tcp.FIN {
buf[n] = 'F'
n++
}
if tcp.RST {
buf[n] = 'R'
n++
}
if tcp.PSH {
buf[n] = 'P'
n++
}
if tcp.ACK {
buf[n] = '.'
n++
}
if tcp.URG {
buf[n] = 'U'
n++
}
if n == 0 {
return "none"
}
return string(buf[:n])
}
func formatTCPOptions(opts []layers.TCPOption) string {
var parts []string
for _, opt := range opts {
switch opt.OptionType {
case layers.TCPOptionKindEndList:
return strings.Join(parts, ",")
case layers.TCPOptionKindNop:
parts = append(parts, "nop")
case layers.TCPOptionKindMSS:
if len(opt.OptionData) == 2 {
parts = append(parts, fmt.Sprintf("mss %d", binary.BigEndian.Uint16(opt.OptionData)))
}
case layers.TCPOptionKindWindowScale:
if len(opt.OptionData) == 1 {
parts = append(parts, fmt.Sprintf("wscale %d", opt.OptionData[0]))
}
case layers.TCPOptionKindSACKPermitted:
parts = append(parts, "sackOK")
case layers.TCPOptionKindSACK:
blocks := len(opt.OptionData) / 8
parts = append(parts, fmt.Sprintf("sack %d", blocks))
case layers.TCPOptionKindTimestamps:
if len(opt.OptionData) == 8 {
tsval := binary.BigEndian.Uint32(opt.OptionData[0:4])
tsecr := binary.BigEndian.Uint32(opt.OptionData[4:8])
parts = append(parts, fmt.Sprintf("TS val %d ecr %d", tsval, tsecr))
}
}
}
return strings.Join(parts, ",")
}
// --- Protocol annotation ---
// annotatePayload returns a protocol annotation string for known application protocols.
func annotatePayload(payload []byte) string {
if len(payload) < 4 {
return ""
}
s := string(payload)
// SSH banner: "SSH-2.0-OpenSSH_9.6\r\n"
if strings.HasPrefix(s, "SSH-") {
if end := strings.IndexByte(s, '\r'); end > 0 && end < 256 {
return ": " + s[:end]
}
}
// TLS records
if ann := annotateTLS(payload); ann != "" {
return ": " + ann
}
// HTTP request or response
for _, method := range [...]string{"GET ", "POST ", "PUT ", "DELETE ", "HEAD ", "PATCH ", "OPTIONS ", "CONNECT "} {
if strings.HasPrefix(s, method) {
if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 {
return ": " + s[:end]
}
}
}
if strings.HasPrefix(s, "HTTP/") {
if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 {
return ": " + s[:end]
}
}
return ""
}
// annotateTLS returns a description for TLS handshake and alert records.
func annotateTLS(data []byte) string {
if len(data) < 6 {
return ""
}
switch data[0] {
case 0x16:
return annotateTLSHandshake(data)
case 0x15:
return annotateTLSAlert(data)
}
return ""
}
func annotateTLSHandshake(data []byte) string {
if len(data) < 10 {
return ""
}
switch data[5] {
case 0x01:
if sni := extractSNI(data); sni != "" {
return "TLS ClientHello SNI=" + sni
}
return "TLS ClientHello"
case 0x02:
return "TLS ServerHello"
}
return ""
}
func annotateTLSAlert(data []byte) string {
if len(data) < 7 {
return ""
}
severity := "warning"
if data[5] == 2 {
severity = "fatal"
}
return fmt.Sprintf("TLS Alert %s %s", severity, tlsAlertDesc(data[6]))
}
func tlsAlertDesc(code byte) string {
switch code {
case 0:
return "close_notify"
case 10:
return "unexpected_message"
case 40:
return "handshake_failure"
case 42:
return "bad_certificate"
case 43:
return "unsupported_certificate"
case 44:
return "certificate_revoked"
case 45:
return "certificate_expired"
case 48:
return "unknown_ca"
case 49:
return "access_denied"
case 50:
return "decode_error"
case 70:
return "protocol_version"
case 80:
return "internal_error"
case 86:
return "inappropriate_fallback"
case 90:
return "user_canceled"
case 112:
return "unrecognized_name"
default:
return fmt.Sprintf("alert(%d)", code)
}
}
// extractSNI parses a TLS ClientHello and returns the SNI server name.
func extractSNI(data []byte) string {
if len(data) < 6 || data[0] != 0x16 {
return ""
}
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
handshake := data[5:]
if len(handshake) > recordLen {
handshake = handshake[:recordLen]
}
if len(handshake) < 4 || handshake[0] != 0x01 {
return ""
}
hsLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
body := handshake[4:]
if len(body) > hsLen {
body = body[:hsLen]
}
extPos := clientHelloExtensionsOffset(body)
if extPos < 0 {
return ""
}
return findSNIExtension(body, extPos)
}
// clientHelloExtensionsOffset returns the byte offset where extensions begin
// within the ClientHello body, or -1 if the body is too short.
func clientHelloExtensionsOffset(body []byte) int {
if len(body) < 38 {
return -1
}
pos := 34
if pos >= len(body) {
return -1
}
pos += 1 + int(body[pos]) // session ID
if pos+2 > len(body) {
return -1
}
pos += 2 + int(binary.BigEndian.Uint16(body[pos:pos+2])) // cipher suites
if pos >= len(body) {
return -1
}
pos += 1 + int(body[pos]) // compression methods
if pos+2 > len(body) {
return -1
}
return pos
}
func findSNIExtension(body []byte, pos int) string {
extLen := int(binary.BigEndian.Uint16(body[pos : pos+2]))
pos += 2
extEnd := pos + extLen
if extEnd > len(body) {
extEnd = len(body)
}
for pos+4 <= extEnd {
extType := binary.BigEndian.Uint16(body[pos : pos+2])
eLen := int(binary.BigEndian.Uint16(body[pos+2 : pos+4]))
pos += 4
if pos+eLen > extEnd {
break
}
if extType == 0 && eLen >= 5 {
nameLen := int(binary.BigEndian.Uint16(body[pos+3 : pos+5]))
if pos+5+nameLen <= extEnd {
return string(body[pos+5 : pos+5+nameLen])
}
}
pos += eLen
}
return ""
}
func isDNSPort(src, dst uint16) bool {
return src == 53 || dst == 53 || src == 5353 || dst == 5353
}
// formatDNSPayload parses DNS and returns a tcpdump-style summary.
func formatDNSPayload(payload []byte) string {
d := &layers.DNS{}
if err := d.DecodeFromBytes(payload, gopacket.NilDecodeFeedback); err != nil {
return ""
}
rd := ""
if d.RD {
rd = "+"
}
if !d.QR {
return formatDNSQuery(d, rd, len(payload))
}
return formatDNSResponse(d, rd, len(payload))
}
func formatDNSQuery(d *layers.DNS, rd string, plen int) string {
if len(d.Questions) == 0 {
return fmt.Sprintf("%04x%s (%d)", d.ID, rd, plen)
}
q := d.Questions[0]
return fmt.Sprintf("%04x%s %s? %s. (%d)", d.ID, rd, q.Type, q.Name, plen)
}
func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
anCount := d.ANCount
nsCount := d.NSCount
arCount := d.ARCount
if d.ResponseCode != layers.DNSResponseCodeNoErr {
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen)
}
if anCount > 0 && len(d.Answers) > 0 {
rr := d.Answers[0]
if rdata := shortRData(&rr); rdata != "" {
return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen)
}
}
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
}
func shortRData(rr *layers.DNSResourceRecord) string {
switch rr.Type {
case layers.DNSTypeA, layers.DNSTypeAAAA:
if rr.IP != nil {
return rr.IP.String()
}
case layers.DNSTypeCNAME:
if len(rr.CNAME) > 0 {
return string(rr.CNAME) + "."
}
case layers.DNSTypePTR:
if len(rr.PTR) > 0 {
return string(rr.PTR) + "."
}
case layers.DNSTypeNS:
if len(rr.NS) > 0 {
return string(rr.NS) + "."
}
case layers.DNSTypeMX:
return fmt.Sprintf("%d %s.", rr.MX.Preference, rr.MX.Name)
case layers.DNSTypeTXT:
if len(rr.TXTs) > 0 {
return fmt.Sprintf("%q", string(rr.TXTs[0]))
}
case layers.DNSTypeSRV:
return fmt.Sprintf("%d %d %d %s.", rr.SRV.Priority, rr.SRV.Weight, rr.SRV.Port, rr.SRV.Name)
}
return ""
}