mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add packet capture to debug bundle and CLI
This commit is contained in:
193
util/capture/afpacket_linux.go
Normal file
193
util/capture/afpacket_linux.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// htons converts a uint16 from host to network (big-endian) byte order.
|
||||
func htons(v uint16) uint16 {
|
||||
var buf [2]byte
|
||||
binary.BigEndian.PutUint16(buf[:], v)
|
||||
return binary.NativeEndian.Uint16(buf[:])
|
||||
}
|
||||
|
||||
// AFPacketCapture reads raw packets from a network interface using an
|
||||
// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is
|
||||
// not available (kernel WireGuard). Linux only.
|
||||
//
|
||||
// It implements device.PacketCapture so it can be set on a Session, but it
|
||||
// drives its own read loop rather than being called from FilteredDevice.
|
||||
// Call Start to begin and Stop to end.
|
||||
type AFPacketCapture struct {
|
||||
ifaceName string
|
||||
sess *Session
|
||||
fd int
|
||||
mu sync.Mutex
|
||||
stopped chan struct{}
|
||||
started atomic.Bool
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewAFPacketCapture creates a capture bound to the given interface.
|
||||
// The session receives packets via Offer.
|
||||
func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture {
|
||||
return &AFPacketCapture{
|
||||
ifaceName: ifaceName,
|
||||
sess: sess,
|
||||
fd: -1,
|
||||
stopped: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start opens the AF_PACKET socket and begins reading packets.
|
||||
// Packets are fed to the session via Offer. Returns immediately;
|
||||
// the read loop runs in a goroutine.
|
||||
func (c *AFPacketCapture) Start() error {
|
||||
if c.sess == nil {
|
||||
return errors.New("nil capture session")
|
||||
}
|
||||
if c.started.Load() {
|
||||
return errors.New("capture already started")
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(c.ifaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface %s: %w", c.ifaceName, err)
|
||||
}
|
||||
|
||||
fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create AF_PACKET socket: %w", err)
|
||||
}
|
||||
|
||||
addr := &unix.SockaddrLinklayer{
|
||||
Protocol: htons(unix.ETH_P_ALL),
|
||||
Ifindex: iface.Index,
|
||||
}
|
||||
if err := unix.Bind(fd, addr); err != nil {
|
||||
unix.Close(fd)
|
||||
return fmt.Errorf("bind to %s: %w", c.ifaceName, err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.fd = fd
|
||||
c.mu.Unlock()
|
||||
|
||||
c.started.Store(true)
|
||||
go c.readLoop(fd)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the socket and waits for the read loop to exit. Idempotent.
|
||||
func (c *AFPacketCapture) Stop() {
|
||||
if !c.closed.CompareAndSwap(false, true) {
|
||||
if c.started.Load() {
|
||||
<-c.stopped
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
fd := c.fd
|
||||
c.fd = -1
|
||||
c.mu.Unlock()
|
||||
|
||||
if fd >= 0 {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
if c.started.Load() {
|
||||
<-c.stopped
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AFPacketCapture) readLoop(fd int) {
|
||||
defer close(c.stopped)
|
||||
|
||||
buf := make([]byte, 65536)
|
||||
pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
|
||||
|
||||
for {
|
||||
if c.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := c.pollOnce(pollFds)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
c.recvAndOffer(fd, buf)
|
||||
}
|
||||
}
|
||||
|
||||
// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry.
|
||||
// Returns an error to signal the loop should exit.
|
||||
func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) {
|
||||
n, err := unix.Poll(pollFds, 200)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
return false, nil
|
||||
}
|
||||
if c.closed.Load() {
|
||||
return false, errors.New("closed")
|
||||
}
|
||||
log.Debugf("af_packet poll: %v", err)
|
||||
return false, err
|
||||
}
|
||||
if n == 0 {
|
||||
return false, nil
|
||||
}
|
||||
if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 {
|
||||
return false, errors.New("fd error")
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) {
|
||||
nr, from, err := unix.Recvfrom(fd, buf, 0)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
|
||||
return
|
||||
}
|
||||
if !c.closed.Load() {
|
||||
log.Debugf("af_packet recvfrom: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if nr < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
ver := buf[0] >> 4
|
||||
if ver != 4 && ver != 6 {
|
||||
return
|
||||
}
|
||||
|
||||
// The kernel sets Pkttype on AF_PACKET sockets:
|
||||
// PACKET_HOST(0) = addressed to us (inbound)
|
||||
// PACKET_OUTGOING(4) = sent by us (outbound)
|
||||
outbound := false
|
||||
if sa, ok := from.(*unix.SockaddrLinklayer); ok {
|
||||
outbound = sa.Pkttype == unix.PACKET_OUTGOING
|
||||
}
|
||||
c.sess.Offer(buf[:nr], outbound)
|
||||
}
|
||||
|
||||
// Offer satisfies device.PacketCapture but is unused: the AFPacketCapture
|
||||
// drives its own read loop. This exists only so the type signature is
|
||||
// compatible if someone tries to set it as a PacketCapture.
|
||||
func (c *AFPacketCapture) Offer([]byte, bool) {
|
||||
// unused: AFPacketCapture drives its own read loop
|
||||
}
|
||||
26
util/capture/afpacket_stub.go
Normal file
26
util/capture/afpacket_stub.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build !linux
|
||||
|
||||
package capture
|
||||
|
||||
import "errors"
|
||||
|
||||
// AFPacketCapture is not available on this platform.
|
||||
type AFPacketCapture struct{}
|
||||
|
||||
// NewAFPacketCapture returns nil on non-Linux platforms.
|
||||
func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil }
|
||||
|
||||
// Start returns an error on non-Linux platforms.
|
||||
func (c *AFPacketCapture) Start() error {
|
||||
return errors.New("AF_PACKET capture is only supported on Linux")
|
||||
}
|
||||
|
||||
// Stop is a no-op on non-Linux platforms.
|
||||
func (c *AFPacketCapture) Stop() {
|
||||
// no-op on non-Linux platforms
|
||||
}
|
||||
|
||||
// Offer is a no-op on non-Linux platforms.
|
||||
func (c *AFPacketCapture) Offer([]byte, bool) {
|
||||
// no-op on non-Linux platforms
|
||||
}
|
||||
59
util/capture/capture.go
Normal file
59
util/capture/capture.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package capture provides userspace packet capture in pcap format.
|
||||
//
|
||||
// It taps decrypted WireGuard packets flowing through the FilteredDevice and
|
||||
// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as
|
||||
// human-readable one-line-per-packet text.
|
||||
package capture
|
||||
|
||||
import "io"
|
||||
|
||||
// Direction indicates whether a packet is entering or leaving the host.
|
||||
type Direction uint8
|
||||
|
||||
const (
|
||||
// Inbound is a packet arriving from the network (FilteredDevice.Write path).
|
||||
Inbound Direction = iota
|
||||
// Outbound is a packet leaving the host (FilteredDevice.Read path).
|
||||
Outbound
|
||||
)
|
||||
|
||||
// String returns "IN" or "OUT".
|
||||
func (d Direction) String() string {
|
||||
if d == Outbound {
|
||||
return "OUT"
|
||||
}
|
||||
return "IN"
|
||||
}
|
||||
|
||||
const (
|
||||
protoICMP = 1
|
||||
protoTCP = 6
|
||||
protoUDP = 17
|
||||
protoICMPv6 = 58
|
||||
)
|
||||
|
||||
// Options configures a capture session.
|
||||
type Options struct {
|
||||
// Output receives pcap-formatted data. Nil disables pcap output.
|
||||
Output io.Writer
|
||||
// TextOutput receives human-readable packet summaries. Nil disables text output.
|
||||
TextOutput io.Writer
|
||||
// Matcher selects which packets to capture. Nil captures all.
|
||||
// Use ParseFilter("host 10.0.0.1 and tcp") or &Filter{...}.
|
||||
Matcher Matcher
|
||||
// Verbose adds seq/ack, TTL, window, total length to text output.
|
||||
Verbose bool
|
||||
// ASCII dumps transport payload as printable ASCII after each packet line.
|
||||
ASCII bool
|
||||
// SnapLen is the maximum bytes captured per packet. 0 means 65535.
|
||||
SnapLen uint32
|
||||
// BufSize is the internal channel buffer size. 0 means 256.
|
||||
BufSize int
|
||||
}
|
||||
|
||||
// Stats reports capture session counters.
|
||||
type Stats struct {
|
||||
Packets int64
|
||||
Bytes int64
|
||||
Dropped int64
|
||||
}
|
||||
528
util/capture/filter.go
Normal file
528
util/capture/filter.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Matcher tests whether a raw packet should be captured.
|
||||
type Matcher interface {
|
||||
Match(data []byte) bool
|
||||
}
|
||||
|
||||
// Filter selects packets by flat AND'd criteria. Useful for structured APIs
|
||||
// (query params, proto fields). Implements Matcher.
|
||||
type Filter struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
Host netip.Addr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
Port uint16
|
||||
Proto uint8
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the filter has no criteria set.
|
||||
func (f *Filter) IsEmpty() bool {
|
||||
return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() &&
|
||||
f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0
|
||||
}
|
||||
|
||||
// Match implements Matcher. All non-zero fields must match (AND).
|
||||
func (f *Filter) Match(data []byte) bool {
|
||||
if f.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
info, ok := parsePacketInfo(data)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host {
|
||||
return false
|
||||
}
|
||||
if f.SrcIP.IsValid() && info.srcIP != f.SrcIP {
|
||||
return false
|
||||
}
|
||||
if f.DstIP.IsValid() && info.dstIP != f.DstIP {
|
||||
return false
|
||||
}
|
||||
if f.Proto != 0 && info.proto != f.Proto {
|
||||
return false
|
||||
}
|
||||
if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port {
|
||||
return false
|
||||
}
|
||||
if f.SrcPort != 0 && info.srcPort != f.SrcPort {
|
||||
return false
|
||||
}
|
||||
if f.DstPort != 0 && info.dstPort != f.DstPort {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// exprNode evaluates a filter condition against pre-parsed packet info.
|
||||
type exprNode func(info *packetInfo) bool
|
||||
|
||||
// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree.
|
||||
type exprMatcher struct {
|
||||
root exprNode
|
||||
}
|
||||
|
||||
func (m *exprMatcher) Match(data []byte) bool {
|
||||
info, ok := parsePacketInfo(data)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m.root(&info)
|
||||
}
|
||||
|
||||
func nodeAnd(a, b exprNode) exprNode {
|
||||
return func(info *packetInfo) bool { return a(info) && b(info) }
|
||||
}
|
||||
|
||||
func nodeOr(a, b exprNode) exprNode {
|
||||
return func(info *packetInfo) bool { return a(info) || b(info) }
|
||||
}
|
||||
|
||||
func nodeNot(n exprNode) exprNode {
|
||||
return func(info *packetInfo) bool { return !n(info) }
|
||||
}
|
||||
|
||||
func nodeHost(addr netip.Addr) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr }
|
||||
}
|
||||
|
||||
func nodeSrcHost(addr netip.Addr) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcIP == addr }
|
||||
}
|
||||
|
||||
func nodeDstHost(addr netip.Addr) exprNode {
|
||||
return func(info *packetInfo) bool { return info.dstIP == addr }
|
||||
}
|
||||
|
||||
func nodePort(port uint16) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port }
|
||||
}
|
||||
|
||||
func nodeSrcPort(port uint16) exprNode {
|
||||
return func(info *packetInfo) bool { return info.srcPort == port }
|
||||
}
|
||||
|
||||
func nodeDstPort(port uint16) exprNode {
|
||||
return func(info *packetInfo) bool { return info.dstPort == port }
|
||||
}
|
||||
|
||||
func nodeProto(proto uint8) exprNode {
|
||||
return func(info *packetInfo) bool { return info.proto == proto }
|
||||
}
|
||||
|
||||
func nodeFamily(family uint8) exprNode {
|
||||
return func(info *packetInfo) bool { return info.family == family }
|
||||
}
|
||||
|
||||
func nodeNet(prefix netip.Prefix) exprNode {
|
||||
return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) }
|
||||
}
|
||||
|
||||
func nodeSrcNet(prefix netip.Prefix) exprNode {
|
||||
return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) }
|
||||
}
|
||||
|
||||
func nodeDstNet(prefix netip.Prefix) exprNode {
|
||||
return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) }
|
||||
}
|
||||
|
||||
// packetInfo holds parsed header fields for filtering and display.
|
||||
type packetInfo struct {
|
||||
family uint8
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
proto uint8
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
hdrLen int
|
||||
}
|
||||
|
||||
func parsePacketInfo(data []byte) (packetInfo, bool) {
|
||||
if len(data) < 1 {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
switch data[0] >> 4 {
|
||||
case 4:
|
||||
return parseIPv4Info(data)
|
||||
case 6:
|
||||
return parseIPv6Info(data)
|
||||
default:
|
||||
return packetInfo{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPv4Info(data []byte) (packetInfo, bool) {
|
||||
if len(data) < 20 {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
ihl := int(data[0]&0x0f) * 4
|
||||
if ihl < 20 || len(data) < ihl {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
info := packetInfo{
|
||||
family: 4,
|
||||
srcIP: netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}),
|
||||
dstIP: netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}),
|
||||
proto: data[9],
|
||||
hdrLen: ihl,
|
||||
}
|
||||
if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= ihl+4 {
|
||||
info.srcPort = binary.BigEndian.Uint16(data[ihl:])
|
||||
info.dstPort = binary.BigEndian.Uint16(data[ihl+2:])
|
||||
}
|
||||
return info, true
|
||||
}
|
||||
|
||||
// parseIPv6Info parses the fixed IPv6 header. It reads the Next Header field
|
||||
// directly, so packets with extension headers (hop-by-hop, routing, fragment,
|
||||
// etc.) will report the extension type as the protocol rather than the final
|
||||
// transport protocol. This is acceptable for a debug capture tool.
|
||||
func parseIPv6Info(data []byte) (packetInfo, bool) {
|
||||
if len(data) < 40 {
|
||||
return packetInfo{}, false
|
||||
}
|
||||
var src, dst [16]byte
|
||||
copy(src[:], data[8:24])
|
||||
copy(dst[:], data[24:40])
|
||||
info := packetInfo{
|
||||
family: 6,
|
||||
srcIP: netip.AddrFrom16(src),
|
||||
dstIP: netip.AddrFrom16(dst),
|
||||
proto: data[6],
|
||||
hdrLen: 40,
|
||||
}
|
||||
if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 {
|
||||
info.srcPort = binary.BigEndian.Uint16(data[40:])
|
||||
info.dstPort = binary.BigEndian.Uint16(data[42:])
|
||||
}
|
||||
return info, true
|
||||
}
|
||||
|
||||
// ParseFilter parses a BPF-like filter expression and returns a Matcher.
|
||||
// Returns nil Matcher for an empty expression (match all).
|
||||
//
|
||||
// Grammar (mirrors common tcpdump BPF syntax):
|
||||
//
|
||||
// orExpr = andExpr ("or" andExpr)*
|
||||
// andExpr = unary ("and" unary)*
|
||||
// unary = "not" unary | "(" orExpr ")" | term
|
||||
//
|
||||
// term = "host" IP | "src" target | "dst" target
|
||||
// | "port" NUM | "net" PREFIX
|
||||
// | "tcp" | "udp" | "icmp" | "icmp6"
|
||||
// | "ip" | "ip6" | "proto" NUM
|
||||
// target = "host" IP | "port" NUM | "net" PREFIX | IP
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// host 10.0.0.1 and tcp port 443
|
||||
// not port 22
|
||||
// (host 10.0.0.1 or host 10.0.0.2) and tcp
|
||||
// ip6 and icmp6
|
||||
// net 10.0.0.0/24
|
||||
// src host 10.0.0.1 or dst port 80
|
||||
func ParseFilter(expr string) (Matcher, error) {
|
||||
tokens := tokenize(expr)
|
||||
if len(tokens) == 0 {
|
||||
return nil, nil //nolint:nilnil // nil Matcher means "match all"
|
||||
}
|
||||
|
||||
p := &parser{tokens: tokens}
|
||||
node, err := p.parseOr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.pos < len(p.tokens) {
|
||||
return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos)
|
||||
}
|
||||
return &exprMatcher{root: node}, nil
|
||||
}
|
||||
|
||||
func tokenize(expr string) []string {
|
||||
expr = strings.TrimSpace(expr)
|
||||
if expr == "" {
|
||||
return nil
|
||||
}
|
||||
// Split on whitespace but keep parens as separate tokens.
|
||||
var tokens []string
|
||||
for _, field := range strings.Fields(expr) {
|
||||
tokens = append(tokens, splitParens(field)...)
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// splitParens splits "(foo)" into "(", "foo", ")".
|
||||
func splitParens(s string) []string {
|
||||
var out []string
|
||||
for strings.HasPrefix(s, "(") {
|
||||
out = append(out, "(")
|
||||
s = s[1:]
|
||||
}
|
||||
var trail []string
|
||||
for strings.HasSuffix(s, ")") {
|
||||
trail = append(trail, ")")
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
out = append(out, trail...)
|
||||
return out
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
tokens []string
|
||||
pos int
|
||||
}
|
||||
|
||||
func (p *parser) peek() string {
|
||||
if p.pos >= len(p.tokens) {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(p.tokens[p.pos])
|
||||
}
|
||||
|
||||
func (p *parser) next() string {
|
||||
tok := p.peek()
|
||||
if tok != "" {
|
||||
p.pos++
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func (p *parser) expect(tok string) error {
|
||||
got := p.next()
|
||||
if got != tok {
|
||||
return fmt.Errorf("expected %q, got %q", tok, got)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *parser) parseOr() (exprNode, error) {
|
||||
left, err := p.parseAnd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for p.peek() == "or" {
|
||||
p.next()
|
||||
right, err := p.parseAnd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
left = nodeOr(left, right)
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseAnd() (exprNode, error) {
|
||||
left, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
tok := p.peek()
|
||||
if tok == "and" {
|
||||
p.next()
|
||||
right, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
left = nodeAnd(left, right)
|
||||
continue
|
||||
}
|
||||
// Implicit AND: two atoms without "and" between them.
|
||||
// Only if the next token starts an atom (not "or", ")", or EOF).
|
||||
if tok != "" && tok != "or" && tok != ")" {
|
||||
right, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
left = nodeAnd(left, right)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseUnary() (exprNode, error) {
|
||||
switch p.peek() {
|
||||
case "not":
|
||||
p.next()
|
||||
inner, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nodeNot(inner), nil
|
||||
case "(":
|
||||
p.next()
|
||||
inner, err := p.parseOr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := p.expect(")"); err != nil {
|
||||
return nil, fmt.Errorf("unclosed parenthesis")
|
||||
}
|
||||
return inner, nil
|
||||
default:
|
||||
return p.parseAtom()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseAtom() (exprNode, error) {
|
||||
tok := p.next()
|
||||
if tok == "" {
|
||||
return nil, fmt.Errorf("unexpected end of expression")
|
||||
}
|
||||
|
||||
switch tok {
|
||||
case "host":
|
||||
addr, err := p.parseAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("host: %w", err)
|
||||
}
|
||||
return nodeHost(addr), nil
|
||||
|
||||
case "port":
|
||||
port, err := p.parsePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("port: %w", err)
|
||||
}
|
||||
return nodePort(port), nil
|
||||
|
||||
case "net":
|
||||
prefix, err := p.parsePrefix()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("net: %w", err)
|
||||
}
|
||||
return nodeNet(prefix), nil
|
||||
|
||||
case "src":
|
||||
return p.parseDirTarget(true)
|
||||
|
||||
case "dst":
|
||||
return p.parseDirTarget(false)
|
||||
|
||||
case "tcp":
|
||||
return nodeProto(protoTCP), nil
|
||||
case "udp":
|
||||
return nodeProto(protoUDP), nil
|
||||
case "icmp":
|
||||
return nodeProto(protoICMP), nil
|
||||
case "icmp6":
|
||||
return nodeProto(protoICMPv6), nil
|
||||
case "ip":
|
||||
return nodeFamily(4), nil
|
||||
case "ip6":
|
||||
return nodeFamily(6), nil
|
||||
|
||||
case "proto":
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return nil, fmt.Errorf("proto: missing number")
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 0 || n > 255 {
|
||||
return nil, fmt.Errorf("proto: invalid number %q", raw)
|
||||
}
|
||||
return nodeProto(uint8(n)), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown filter keyword %q", tok)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseDirTarget(isSrc bool) (exprNode, error) {
|
||||
tok := p.peek()
|
||||
switch tok {
|
||||
case "host":
|
||||
p.next()
|
||||
addr, err := p.parseAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcHost(addr), nil
|
||||
}
|
||||
return nodeDstHost(addr), nil
|
||||
|
||||
case "port":
|
||||
p.next()
|
||||
port, err := p.parsePort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcPort(port), nil
|
||||
}
|
||||
return nodeDstPort(port), nil
|
||||
|
||||
case "net":
|
||||
p.next()
|
||||
prefix, err := p.parsePrefix()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcNet(prefix), nil
|
||||
}
|
||||
return nodeDstNet(prefix), nil
|
||||
|
||||
default:
|
||||
// Try as bare IP: "src 10.0.0.1"
|
||||
addr, err := p.parseAddr()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok)
|
||||
}
|
||||
if isSrc {
|
||||
return nodeSrcHost(addr), nil
|
||||
}
|
||||
return nodeDstHost(addr), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) parseAddr() (netip.Addr, error) {
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return netip.Addr{}, fmt.Errorf("missing IP address")
|
||||
}
|
||||
addr, err := netip.ParseAddr(raw)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP %q", raw)
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func (p *parser) parsePort() (uint16, error) {
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return 0, fmt.Errorf("missing port number")
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 1 || n > 65535 {
|
||||
return 0, fmt.Errorf("invalid port %q", raw)
|
||||
}
|
||||
return uint16(n), nil
|
||||
}
|
||||
|
||||
func (p *parser) parsePrefix() (netip.Prefix, error) {
|
||||
raw := p.next()
|
||||
if raw == "" {
|
||||
return netip.Prefix{}, fmt.Errorf("missing network prefix")
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(raw)
|
||||
if err != nil {
|
||||
return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw)
|
||||
}
|
||||
return prefix, nil
|
||||
}
|
||||
263
util/capture/filter_test.go
Normal file
263
util/capture/filter_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing.
|
||||
func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
hdrLen := 20
|
||||
pkt := make([]byte, hdrLen+20)
|
||||
pkt[0] = 0x45
|
||||
pkt[9] = proto
|
||||
|
||||
src := srcIP.As4()
|
||||
dst := dstIP.As4()
|
||||
copy(pkt[12:16], src[:])
|
||||
copy(pkt[16:20], dst[:])
|
||||
|
||||
pkt[20] = byte(srcPort >> 8)
|
||||
pkt[21] = byte(srcPort)
|
||||
pkt[22] = byte(dstPort >> 8)
|
||||
pkt[23] = byte(dstPort)
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing.
|
||||
func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
|
||||
pkt := make([]byte, 44) // 40 header + 4 ports
|
||||
pkt[0] = 0x60 // version 6
|
||||
pkt[6] = proto // next header
|
||||
|
||||
src := srcIP.As16()
|
||||
dst := dstIP.As16()
|
||||
copy(pkt[8:24], src[:])
|
||||
copy(pkt[24:40], dst[:])
|
||||
|
||||
pkt[40] = byte(srcPort >> 8)
|
||||
pkt[41] = byte(srcPort)
|
||||
pkt[42] = byte(dstPort >> 8)
|
||||
pkt[43] = byte(dstPort)
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
// ---- Filter struct tests ----
|
||||
|
||||
func TestFilter_Empty(t *testing.T) {
|
||||
f := Filter{}
|
||||
assert.True(t, f.IsEmpty())
|
||||
assert.True(t, f.Match(buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 443)))
|
||||
}
|
||||
|
||||
func TestFilter_Host(t *testing.T) {
|
||||
f := Filter{Host: netip.MustParseAddr("10.0.0.1")}
|
||||
assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80)))
|
||||
assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80)))
|
||||
assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80)))
|
||||
}
|
||||
|
||||
func TestFilter_InvalidPacket(t *testing.T) {
|
||||
f := Filter{Host: netip.MustParseAddr("10.0.0.1")}
|
||||
assert.False(t, f.Match(nil))
|
||||
assert.False(t, f.Match([]byte{}))
|
||||
assert.False(t, f.Match([]byte{0x00}))
|
||||
}
|
||||
|
||||
func TestParsePacketInfo_IPv4(t *testing.T) {
|
||||
pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80)
|
||||
info, ok := parsePacketInfo(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, uint8(4), info.family)
|
||||
assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP)
|
||||
assert.Equal(t, uint8(protoTCP), info.proto)
|
||||
assert.Equal(t, uint16(54321), info.srcPort)
|
||||
assert.Equal(t, uint16(80), info.dstPort)
|
||||
}
|
||||
|
||||
func TestParsePacketInfo_IPv6(t *testing.T) {
|
||||
pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53)
|
||||
info, ok := parsePacketInfo(pkt)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, uint8(6), info.family)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP)
|
||||
assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP)
|
||||
assert.Equal(t, uint8(protoUDP), info.proto)
|
||||
assert.Equal(t, uint16(1234), info.srcPort)
|
||||
assert.Equal(t, uint16(53), info.dstPort)
|
||||
}
|
||||
|
||||
// ---- ParseFilter expression tests ----
|
||||
|
||||
func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool {
|
||||
t.Helper()
|
||||
return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort))
|
||||
}
|
||||
|
||||
func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool {
|
||||
t.Helper()
|
||||
return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort))
|
||||
}
|
||||
|
||||
func TestParseFilter_Empty(t *testing.T) {
|
||||
m, err := ParseFilter("")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, m, "empty expression should return nil matcher")
|
||||
}
|
||||
|
||||
func TestParseFilter_Atoms(t *testing.T) {
|
||||
tests := []struct {
|
||||
expr string
|
||||
match bool
|
||||
}{
|
||||
{"tcp", true},
|
||||
{"udp", false},
|
||||
{"host 10.0.0.1", true},
|
||||
{"host 10.0.0.99", false},
|
||||
{"port 443", true},
|
||||
{"port 80", false},
|
||||
{"src host 10.0.0.1", true},
|
||||
{"dst host 10.0.0.2", true},
|
||||
{"dst host 10.0.0.1", false},
|
||||
{"src port 12345", true},
|
||||
{"dst port 443", true},
|
||||
{"dst port 80", false},
|
||||
{"proto 6", true},
|
||||
{"proto 17", false},
|
||||
}
|
||||
|
||||
pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expr, func(t *testing.T) {
|
||||
m, err := ParseFilter(tt.expr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.match, m.Match(pkt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFilter_And(t *testing.T) {
|
||||
m, err := ParseFilter("host 10.0.0.1 and tcp port 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto")
|
||||
assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host")
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port")
|
||||
}
|
||||
|
||||
func TestParseFilter_ImplicitAnd(t *testing.T) {
|
||||
// "tcp port 443" = implicit AND between tcp and port 443
|
||||
m, err := ParseFilter("tcp port 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443))
|
||||
}
|
||||
|
||||
func TestParseFilter_Or(t *testing.T) {
|
||||
m, err := ParseFilter("port 80 or port 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80))
|
||||
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080))
|
||||
}
|
||||
|
||||
func TestParseFilter_Not(t *testing.T) {
|
||||
m, err := ParseFilter("not port 22")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80))
|
||||
}
|
||||
|
||||
func TestParseFilter_Parens(t *testing.T) {
|
||||
m, err := ParseFilter("(port 80 or port 443) and tcp")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443))
|
||||
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto")
|
||||
assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port")
|
||||
}
|
||||
|
||||
func TestParseFilter_Family(t *testing.T) {
|
||||
mV4, err := ParseFilter("ip")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80))
|
||||
assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80))
|
||||
|
||||
mV6, err := ParseFilter("ip6")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80))
|
||||
assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80))
|
||||
}
|
||||
|
||||
func TestParseFilter_Net(t *testing.T) {
|
||||
m, err := ParseFilter("net 10.0.0.0/24")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net")
|
||||
assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net")
|
||||
assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net")
|
||||
}
|
||||
|
||||
func TestParseFilter_SrcDstNet(t *testing.T) {
|
||||
m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80))
|
||||
assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed")
|
||||
}
|
||||
|
||||
func TestParseFilter_Complex(t *testing.T) {
|
||||
// Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH
|
||||
m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443))
|
||||
assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded")
|
||||
assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host")
|
||||
}
|
||||
|
||||
func TestParseFilter_IPv6Combined(t *testing.T) {
|
||||
m, err := ParseFilter("ip6 and icmp6")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0))
|
||||
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family")
|
||||
assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto")
|
||||
}
|
||||
|
||||
func TestParseFilter_CaseInsensitive(t *testing.T) {
|
||||
m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443))
|
||||
}
|
||||
|
||||
func TestParseFilter_Errors(t *testing.T) {
|
||||
bad := []string{
|
||||
"badkeyword",
|
||||
"host",
|
||||
"port abc",
|
||||
"port 99999",
|
||||
"net invalid",
|
||||
"(",
|
||||
"(port 80",
|
||||
"not",
|
||||
"src",
|
||||
}
|
||||
for _, expr := range bad {
|
||||
t.Run(expr, func(t *testing.T) {
|
||||
_, err := ParseFilter(expr)
|
||||
assert.Error(t, err, "should fail for %q", expr)
|
||||
})
|
||||
}
|
||||
}
|
||||
85
util/capture/pcap.go
Normal file
85
util/capture/pcap.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
pcapMagic = 0xa1b2c3d4
|
||||
pcapVersionMaj = 2
|
||||
pcapVersionMin = 4
|
||||
// linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header.
|
||||
linkTypeRaw = 101
|
||||
defaultSnapLen = 65535
|
||||
)
|
||||
|
||||
// PcapWriter writes packets in pcap format to an underlying writer.
|
||||
// The global header is written lazily on the first WritePacket call so that
|
||||
// the writer can be used with unbuffered io.Pipes without deadlocking.
|
||||
// It is not safe for concurrent use; callers must serialize access.
|
||||
type PcapWriter struct {
|
||||
w io.Writer
|
||||
snapLen uint32
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
// NewPcapWriter creates a pcap writer. The global header is deferred until the
|
||||
// first WritePacket call.
|
||||
func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter {
|
||||
if snapLen == 0 {
|
||||
snapLen = defaultSnapLen
|
||||
}
|
||||
return &PcapWriter{w: w, snapLen: snapLen}
|
||||
}
|
||||
|
||||
// writeGlobalHeader writes the 24-byte pcap file header.
|
||||
func (pw *PcapWriter) writeGlobalHeader() error {
|
||||
var hdr [24]byte
|
||||
binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic)
|
||||
binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj)
|
||||
binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin)
|
||||
binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen)
|
||||
binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw)
|
||||
|
||||
_, err := pw.w.Write(hdr[:])
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteHeader writes the pcap global header. Safe to call multiple times.
|
||||
func (pw *PcapWriter) WriteHeader() error {
|
||||
if pw.headerWritten {
|
||||
return nil
|
||||
}
|
||||
if err := pw.writeGlobalHeader(); err != nil {
|
||||
return err
|
||||
}
|
||||
pw.headerWritten = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// WritePacket writes a single packet record, preceded by the global header
|
||||
// on the first call.
|
||||
func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error {
|
||||
if err := pw.WriteHeader(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
origLen := uint32(len(data))
|
||||
if origLen > pw.snapLen {
|
||||
data = data[:pw.snapLen]
|
||||
}
|
||||
|
||||
var hdr [16]byte
|
||||
binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix()))
|
||||
binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000))
|
||||
binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data)))
|
||||
binary.LittleEndian.PutUint32(hdr[12:16], origLen)
|
||||
|
||||
if _, err := pw.w.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := pw.w.Write(data)
|
||||
return err
|
||||
}
|
||||
68
util/capture/pcap_test.go
Normal file
68
util/capture/pcap_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPcapWriter_GlobalHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
pw := NewPcapWriter(&buf, 0)
|
||||
|
||||
// Header is lazy, so write a dummy packet to trigger it.
|
||||
err := pw.WritePacket(time.Now(), []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 64, 1, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2})
|
||||
require.NoError(t, err)
|
||||
|
||||
data := buf.Bytes()
|
||||
require.GreaterOrEqual(t, len(data), 24, "should contain global header")
|
||||
|
||||
assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]), "magic number")
|
||||
assert.Equal(t, uint16(pcapVersionMaj), binary.LittleEndian.Uint16(data[4:6]), "version major")
|
||||
assert.Equal(t, uint16(pcapVersionMin), binary.LittleEndian.Uint16(data[6:8]), "version minor")
|
||||
assert.Equal(t, uint32(defaultSnapLen), binary.LittleEndian.Uint32(data[16:20]), "snap length")
|
||||
assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]), "link type")
|
||||
}
|
||||
|
||||
func TestPcapWriter_WritePacket(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
pw := NewPcapWriter(&buf, 100)
|
||||
|
||||
ts := time.Date(2025, 6, 15, 12, 30, 45, 123456000, time.UTC)
|
||||
payload := make([]byte, 50)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
err := pw.WritePacket(ts, payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := buf.Bytes()[24:] // skip global header
|
||||
require.Len(t, data, 16+50, "packet header + payload")
|
||||
|
||||
assert.Equal(t, uint32(ts.Unix()), binary.LittleEndian.Uint32(data[0:4]), "timestamp seconds")
|
||||
assert.Equal(t, uint32(123456), binary.LittleEndian.Uint32(data[4:8]), "timestamp microseconds")
|
||||
assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[8:12]), "included length")
|
||||
assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length")
|
||||
assert.Equal(t, payload, data[16:], "packet data")
|
||||
}
|
||||
|
||||
func TestPcapWriter_SnapLen(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
pw := NewPcapWriter(&buf, 10)
|
||||
|
||||
ts := time.Now()
|
||||
payload := make([]byte, 50)
|
||||
|
||||
err := pw.WritePacket(ts, payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := buf.Bytes()[24:]
|
||||
assert.Equal(t, uint32(10), binary.LittleEndian.Uint32(data[8:12]), "included length should be truncated")
|
||||
assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length preserved")
|
||||
assert.Len(t, data[16:], 10, "only snap_len bytes written")
|
||||
}
|
||||
213
util/capture/session.go
Normal file
213
util/capture/session.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultBufSize = 256
|
||||
|
||||
type packetEntry struct {
|
||||
ts time.Time
|
||||
data []byte
|
||||
dir Direction
|
||||
}
|
||||
|
||||
// Session manages an active packet capture. Packets are offered via Offer,
|
||||
// buffered in a channel, and written to configured sinks by a background
|
||||
// goroutine. This keeps the hot path (FilteredDevice.Read/Write) non-blocking.
|
||||
//
|
||||
// The caller must call Stop when done to flush remaining packets and release
|
||||
// resources.
|
||||
type Session struct {
|
||||
pcapW *PcapWriter
|
||||
textW *TextWriter
|
||||
matcher Matcher
|
||||
snapLen uint32
|
||||
flushFn func()
|
||||
|
||||
ch chan packetEntry
|
||||
done chan struct{}
|
||||
stopped chan struct{}
|
||||
|
||||
closeOnce sync.Once
|
||||
closed atomic.Bool
|
||||
packets atomic.Int64
|
||||
bytes atomic.Int64
|
||||
dropped atomic.Int64
|
||||
started time.Time
|
||||
}
|
||||
|
||||
// NewSession creates and starts a capture session. At least one of
|
||||
// Options.Output or Options.TextOutput must be non-nil.
|
||||
func NewSession(opts Options) (*Session, error) {
|
||||
if opts.Output == nil && opts.TextOutput == nil {
|
||||
return nil, fmt.Errorf("at least one output sink required")
|
||||
}
|
||||
|
||||
snapLen := opts.SnapLen
|
||||
if snapLen == 0 {
|
||||
snapLen = defaultSnapLen
|
||||
}
|
||||
|
||||
bufSize := opts.BufSize
|
||||
if bufSize <= 0 {
|
||||
bufSize = defaultBufSize
|
||||
}
|
||||
|
||||
s := &Session{
|
||||
matcher: opts.Matcher,
|
||||
snapLen: snapLen,
|
||||
ch: make(chan packetEntry, bufSize),
|
||||
done: make(chan struct{}),
|
||||
stopped: make(chan struct{}),
|
||||
started: time.Now(),
|
||||
}
|
||||
|
||||
if opts.Output != nil {
|
||||
s.pcapW = NewPcapWriter(opts.Output, snapLen)
|
||||
}
|
||||
if opts.TextOutput != nil {
|
||||
s.textW = NewTextWriter(opts.TextOutput, opts.Verbose, opts.ASCII)
|
||||
}
|
||||
|
||||
s.flushFn = buildFlushFn(opts.Output, opts.TextOutput)
|
||||
|
||||
go s.run()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Offer submits a packet for capture. It returns immediately and never blocks
|
||||
// the caller. If the internal buffer is full the packet is dropped silently.
|
||||
//
|
||||
// outbound should be true for packets leaving the host (FilteredDevice.Read
|
||||
// path) and false for packets arriving (FilteredDevice.Write path).
|
||||
//
|
||||
// Offer satisfies the device.PacketCapture interface.
|
||||
func (s *Session) Offer(data []byte, outbound bool) {
|
||||
if s.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
if s.matcher != nil && !s.matcher.Match(data) {
|
||||
return
|
||||
}
|
||||
|
||||
captureLen := len(data)
|
||||
if s.snapLen > 0 && uint32(captureLen) > s.snapLen {
|
||||
captureLen = int(s.snapLen)
|
||||
}
|
||||
|
||||
copied := make([]byte, captureLen)
|
||||
copy(copied, data)
|
||||
|
||||
dir := Inbound
|
||||
if outbound {
|
||||
dir = Outbound
|
||||
}
|
||||
|
||||
select {
|
||||
case s.ch <- packetEntry{ts: time.Now(), data: copied, dir: dir}:
|
||||
s.packets.Add(1)
|
||||
s.bytes.Add(int64(len(data)))
|
||||
default:
|
||||
s.dropped.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop signals the session to stop accepting packets, drains any buffered
|
||||
// packets to the sinks, and waits for the writer goroutine to exit.
|
||||
// It is safe to call multiple times.
|
||||
func (s *Session) Stop() {
|
||||
s.closeOnce.Do(func() {
|
||||
s.closed.Store(true)
|
||||
close(s.done)
|
||||
})
|
||||
<-s.stopped
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the session's writer goroutine
|
||||
// has fully exited and all buffered packets have been flushed.
|
||||
func (s *Session) Done() <-chan struct{} {
|
||||
return s.stopped
|
||||
}
|
||||
|
||||
// Stats returns current capture counters.
|
||||
func (s *Session) Stats() Stats {
|
||||
return Stats{
|
||||
Packets: s.packets.Load(),
|
||||
Bytes: s.bytes.Load(),
|
||||
Dropped: s.dropped.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) run() {
|
||||
defer close(s.stopped)
|
||||
|
||||
for {
|
||||
select {
|
||||
case pkt := <-s.ch:
|
||||
s.write(pkt)
|
||||
case <-s.done:
|
||||
s.drain()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) drain() {
|
||||
for {
|
||||
select {
|
||||
case pkt := <-s.ch:
|
||||
s.write(pkt)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) write(pkt packetEntry) {
|
||||
if s.pcapW != nil {
|
||||
// Best-effort: if the writer fails (broken pipe etc.), discard silently.
|
||||
_ = s.pcapW.WritePacket(pkt.ts, pkt.data)
|
||||
}
|
||||
if s.textW != nil {
|
||||
_ = s.textW.WritePacket(pkt.ts, pkt.data, pkt.dir)
|
||||
}
|
||||
s.flushFn()
|
||||
}
|
||||
|
||||
// buildFlushFn returns a function that flushes all writers that support it.
|
||||
// This covers http.Flusher and similar streaming writers.
|
||||
func buildFlushFn(writers ...any) func() {
|
||||
type flusher interface {
|
||||
Flush()
|
||||
}
|
||||
|
||||
var fns []func()
|
||||
for _, w := range writers {
|
||||
if w == nil {
|
||||
continue
|
||||
}
|
||||
if f, ok := w.(flusher); ok {
|
||||
fns = append(fns, f.Flush)
|
||||
}
|
||||
}
|
||||
|
||||
switch len(fns) {
|
||||
case 0:
|
||||
return func() {
|
||||
// no writers to flush
|
||||
}
|
||||
case 1:
|
||||
return fns[0]
|
||||
default:
|
||||
return func() {
|
||||
for _, fn := range fns {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
144
util/capture/session_test.go
Normal file
144
util/capture/session_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSession_PcapOutput(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sess, err := NewSession(Options{
|
||||
Output: &buf,
|
||||
BufSize: 16,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pkt := buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 443)
|
||||
|
||||
sess.Offer(pkt, true)
|
||||
sess.Stop()
|
||||
|
||||
data := buf.Bytes()
|
||||
require.Greater(t, len(data), 24, "should have global header + at least one packet")
|
||||
|
||||
// Verify global header
|
||||
assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]))
|
||||
assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]))
|
||||
|
||||
// Verify packet record
|
||||
pktData := data[24:]
|
||||
inclLen := binary.LittleEndian.Uint32(pktData[8:12])
|
||||
assert.Equal(t, uint32(len(pkt)), inclLen)
|
||||
|
||||
stats := sess.Stats()
|
||||
assert.Equal(t, int64(1), stats.Packets)
|
||||
assert.Equal(t, int64(len(pkt)), stats.Bytes)
|
||||
assert.Equal(t, int64(0), stats.Dropped)
|
||||
}
|
||||
|
||||
func TestSession_TextOutput(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sess, err := NewSession(Options{
|
||||
TextOutput: &buf,
|
||||
BufSize: 16,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pkt := buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 443)
|
||||
|
||||
sess.Offer(pkt, false)
|
||||
sess.Stop()
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "TCP")
|
||||
assert.Contains(t, output, "10.0.0.1")
|
||||
assert.Contains(t, output, "10.0.0.2")
|
||||
assert.Contains(t, output, "443")
|
||||
assert.Contains(t, output, "[IN TCP]")
|
||||
}
|
||||
|
||||
func TestSession_Filter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sess, err := NewSession(Options{
|
||||
Output: &buf,
|
||||
Matcher: &Filter{Port: 443},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pktMatch := buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 443)
|
||||
pktNoMatch := buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 80)
|
||||
|
||||
sess.Offer(pktMatch, true)
|
||||
sess.Offer(pktNoMatch, true)
|
||||
sess.Stop()
|
||||
|
||||
stats := sess.Stats()
|
||||
assert.Equal(t, int64(1), stats.Packets, "only matching packet should be captured")
|
||||
}
|
||||
|
||||
func TestSession_StopIdempotent(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sess, err := NewSession(Options{Output: &buf})
|
||||
require.NoError(t, err)
|
||||
|
||||
sess.Stop()
|
||||
sess.Stop() // should not panic or deadlock
|
||||
}
|
||||
|
||||
func TestSession_OfferAfterStop(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sess, err := NewSession(Options{Output: &buf})
|
||||
require.NoError(t, err)
|
||||
sess.Stop()
|
||||
|
||||
pkt := buildIPv4Packet(t,
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
protoTCP, 12345, 443)
|
||||
sess.Offer(pkt, true) // should not panic
|
||||
|
||||
assert.Equal(t, int64(0), sess.Stats().Packets)
|
||||
}
|
||||
|
||||
func TestSession_Done(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
sess, err := NewSession(Options{Output: &buf})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-sess.Done():
|
||||
t.Fatal("Done should not be closed before Stop")
|
||||
default:
|
||||
}
|
||||
|
||||
sess.Stop()
|
||||
|
||||
select {
|
||||
case <-sess.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Done should be closed after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSession_RequiresOutput(t *testing.T) {
|
||||
_, err := NewSession(Options{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
638
util/capture/text.go
Normal file
638
util/capture/text.go
Normal file
@@ -0,0 +1,638 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// TextWriter writes human-readable one-line-per-packet summaries.
|
||||
// It is not safe for concurrent use; callers must serialize access.
|
||||
type TextWriter struct {
|
||||
w io.Writer
|
||||
verbose bool
|
||||
ascii bool
|
||||
flows map[dirKey]uint32
|
||||
}
|
||||
|
||||
type dirKey struct {
|
||||
src netip.AddrPort
|
||||
dst netip.AddrPort
|
||||
}
|
||||
|
||||
// NewTextWriter creates a text formatter that writes to w.
|
||||
func NewTextWriter(w io.Writer, verbose, ascii bool) *TextWriter {
|
||||
return &TextWriter{
|
||||
w: w,
|
||||
verbose: verbose,
|
||||
ascii: ascii,
|
||||
flows: make(map[dirKey]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
// tag formats the fixed-width "[DIR PROTO]" prefix with right-aligned protocol.
|
||||
func tag(dir Direction, proto string) string {
|
||||
return fmt.Sprintf("[%-3s %4s]", dir, proto)
|
||||
}
|
||||
|
||||
// WritePacket formats and writes a single packet line.
|
||||
func (tw *TextWriter) WritePacket(ts time.Time, data []byte, dir Direction) error {
|
||||
ts = ts.Local()
|
||||
info, ok := parsePacketInfo(data)
|
||||
if !ok {
|
||||
_, err := fmt.Fprintf(tw.w, "%s [%-3s ?] ??? len=%d\n",
|
||||
ts.Format("15:04:05.000000"), dir, len(data))
|
||||
return err
|
||||
}
|
||||
|
||||
timeStr := ts.Format("15:04:05.000000")
|
||||
|
||||
var err error
|
||||
switch info.proto {
|
||||
case protoTCP:
|
||||
err = tw.writeTCP(timeStr, dir, &info, data)
|
||||
case protoUDP:
|
||||
err = tw.writeUDP(timeStr, dir, &info, data)
|
||||
case protoICMP:
|
||||
err = tw.writeICMPv4(timeStr, dir, &info, data)
|
||||
case protoICMPv6:
|
||||
err = tw.writeICMPv6(timeStr, dir, &info, data)
|
||||
default:
|
||||
var verbose string
|
||||
if tw.verbose {
|
||||
verbose = tw.verboseIP(data, info.family)
|
||||
}
|
||||
_, err = fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n",
|
||||
timeStr, tag(dir, fmt.Sprintf("P%d", info.proto)),
|
||||
info.srcIP, info.dstIP, len(data)-info.hdrLen, verbose)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, data []byte) error {
|
||||
tcp := &layers.TCP{}
|
||||
if err := tcp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
|
||||
return tw.writeFallback(timeStr, dir, "TCP", info, data)
|
||||
}
|
||||
|
||||
flags := tcpFlagsStr(tcp)
|
||||
plen := len(tcp.Payload)
|
||||
|
||||
// Protocol annotation
|
||||
var annotation string
|
||||
if plen > 0 {
|
||||
annotation = annotatePayload(tcp.Payload)
|
||||
}
|
||||
|
||||
if !tw.verbose {
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s] length %d%s\n",
|
||||
timeStr, tag(dir, "TCP"),
|
||||
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
|
||||
flags, plen, annotation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tw.ascii && plen > 0 {
|
||||
return tw.writeASCII(tcp.Payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
relSeq, relAck := tw.relativeSeqAck(info, tcp.Seq, tcp.Ack)
|
||||
|
||||
var seqStr string
|
||||
if plen > 0 {
|
||||
seqStr = fmt.Sprintf(", seq %d:%d", relSeq, relSeq+uint32(plen))
|
||||
} else {
|
||||
seqStr = fmt.Sprintf(", seq %d", relSeq)
|
||||
}
|
||||
|
||||
var ackStr string
|
||||
if tcp.ACK {
|
||||
ackStr = fmt.Sprintf(", ack %d", relAck)
|
||||
}
|
||||
|
||||
var opts string
|
||||
if s := formatTCPOptions(tcp.Options); s != "" {
|
||||
opts = ", options [" + s + "]"
|
||||
}
|
||||
|
||||
verbose := tw.verboseIP(data, info.family)
|
||||
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s]%s%s, win %d%s, length %d%s%s\n",
|
||||
timeStr, tag(dir, "TCP"),
|
||||
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
|
||||
flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tw.ascii && plen > 0 {
|
||||
return tw.writeASCII(tcp.Payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, data []byte) error {
|
||||
udp := &layers.UDP{}
|
||||
if err := udp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
|
||||
return tw.writeFallback(timeStr, dir, "UDP", info, data)
|
||||
}
|
||||
|
||||
plen := len(udp.Payload)
|
||||
|
||||
// DNS replaces the entire line format
|
||||
if plen > 0 && isDNSPort(info.srcPort, info.dstPort) {
|
||||
if s := formatDNSPayload(udp.Payload); s != "" {
|
||||
var verbose string
|
||||
if tw.verbose {
|
||||
verbose = tw.verboseIP(data, info.family)
|
||||
}
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d %s%s\n",
|
||||
timeStr, tag(dir, "UDP"),
|
||||
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
|
||||
s, verbose)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var verbose string
|
||||
if tw.verbose {
|
||||
verbose = tw.verboseIP(data, info.family)
|
||||
}
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d%s\n",
|
||||
timeStr, tag(dir, "UDP"),
|
||||
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
|
||||
plen, verbose)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tw.ascii && plen > 0 {
|
||||
return tw.writeASCII(udp.Payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tw *TextWriter) writeICMPv4(timeStr string, dir Direction, info *packetInfo, data []byte) error {
|
||||
icmp := &layers.ICMPv4{}
|
||||
if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
|
||||
return tw.writeFallback(timeStr, dir, "ICMP", info, data)
|
||||
}
|
||||
|
||||
var detail string
|
||||
if icmp.TypeCode.Type() == layers.ICMPv4TypeEchoRequest || icmp.TypeCode.Type() == layers.ICMPv4TypeEchoReply {
|
||||
detail = fmt.Sprintf("%s, id %d, seq %d", icmp.TypeCode.String(), icmp.Id, icmp.Seq)
|
||||
} else {
|
||||
detail = icmp.TypeCode.String()
|
||||
}
|
||||
|
||||
var verbose string
|
||||
if tw.verbose {
|
||||
verbose = tw.verboseIP(data, info.family)
|
||||
}
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n",
|
||||
timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, detail, len(data)-info.hdrLen, verbose)
|
||||
return err
|
||||
}
|
||||
|
||||
func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInfo, data []byte) error {
|
||||
icmp := &layers.ICMPv6{}
|
||||
if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil {
|
||||
return tw.writeFallback(timeStr, dir, "ICMP", info, data)
|
||||
}
|
||||
|
||||
var verbose string
|
||||
if tw.verbose {
|
||||
verbose = tw.verboseIP(data, info.family)
|
||||
}
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n",
|
||||
timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, icmp.TypeCode.String(), len(data)-info.hdrLen, verbose)
|
||||
return err
|
||||
}
|
||||
|
||||
func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error {
|
||||
_, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d\n",
|
||||
timeStr, tag(dir, proto),
|
||||
info.srcIP, info.srcPort, info.dstIP, info.dstPort,
|
||||
len(data)-info.hdrLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (tw *TextWriter) verboseIP(data []byte, family uint8) string {
|
||||
return fmt.Sprintf(", ttl %d, id %d, iplen %d",
|
||||
ipTTL(data, family), ipID(data, family), len(data))
|
||||
}
|
||||
|
||||
// relativeSeqAck returns seq/ack relative to the first seen value per direction.
|
||||
func (tw *TextWriter) relativeSeqAck(info *packetInfo, seq, ack uint32) (relSeq, relAck uint32) {
|
||||
fwd := dirKey{
|
||||
src: netip.AddrPortFrom(info.srcIP, info.srcPort),
|
||||
dst: netip.AddrPortFrom(info.dstIP, info.dstPort),
|
||||
}
|
||||
rev := dirKey{
|
||||
src: netip.AddrPortFrom(info.dstIP, info.dstPort),
|
||||
dst: netip.AddrPortFrom(info.srcIP, info.srcPort),
|
||||
}
|
||||
|
||||
if isn, ok := tw.flows[fwd]; ok {
|
||||
relSeq = seq - isn
|
||||
} else {
|
||||
tw.flows[fwd] = seq
|
||||
}
|
||||
|
||||
if isn, ok := tw.flows[rev]; ok {
|
||||
relAck = ack - isn
|
||||
} else {
|
||||
relAck = ack
|
||||
}
|
||||
|
||||
return relSeq, relAck
|
||||
}
|
||||
|
||||
// writeASCII prints payload bytes as printable ASCII.
|
||||
func (tw *TextWriter) writeASCII(payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
buf := make([]byte, len(payload))
|
||||
for i, b := range payload {
|
||||
switch {
|
||||
case b >= 0x20 && b < 0x7f:
|
||||
buf[i] = b
|
||||
case b == '\n' || b == '\r' || b == '\t':
|
||||
buf[i] = b
|
||||
default:
|
||||
buf[i] = '.'
|
||||
}
|
||||
}
|
||||
_, err := fmt.Fprintf(tw.w, "%s\n", buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- TCP helpers ---
|
||||
|
||||
func ipTTL(data []byte, family uint8) uint8 {
|
||||
if family == 4 && len(data) > 8 {
|
||||
return data[8]
|
||||
}
|
||||
if family == 6 && len(data) > 7 {
|
||||
return data[7]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func ipID(data []byte, family uint8) uint16 {
|
||||
if family == 4 && len(data) >= 6 {
|
||||
return binary.BigEndian.Uint16(data[4:6])
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func tcpFlagsStr(tcp *layers.TCP) string {
|
||||
var buf [6]byte
|
||||
n := 0
|
||||
if tcp.SYN {
|
||||
buf[n] = 'S'
|
||||
n++
|
||||
}
|
||||
if tcp.FIN {
|
||||
buf[n] = 'F'
|
||||
n++
|
||||
}
|
||||
if tcp.RST {
|
||||
buf[n] = 'R'
|
||||
n++
|
||||
}
|
||||
if tcp.PSH {
|
||||
buf[n] = 'P'
|
||||
n++
|
||||
}
|
||||
if tcp.ACK {
|
||||
buf[n] = '.'
|
||||
n++
|
||||
}
|
||||
if tcp.URG {
|
||||
buf[n] = 'U'
|
||||
n++
|
||||
}
|
||||
if n == 0 {
|
||||
return "none"
|
||||
}
|
||||
return string(buf[:n])
|
||||
}
|
||||
|
||||
func formatTCPOptions(opts []layers.TCPOption) string {
|
||||
var parts []string
|
||||
for _, opt := range opts {
|
||||
switch opt.OptionType {
|
||||
case layers.TCPOptionKindEndList:
|
||||
return strings.Join(parts, ",")
|
||||
case layers.TCPOptionKindNop:
|
||||
parts = append(parts, "nop")
|
||||
case layers.TCPOptionKindMSS:
|
||||
if len(opt.OptionData) == 2 {
|
||||
parts = append(parts, fmt.Sprintf("mss %d", binary.BigEndian.Uint16(opt.OptionData)))
|
||||
}
|
||||
case layers.TCPOptionKindWindowScale:
|
||||
if len(opt.OptionData) == 1 {
|
||||
parts = append(parts, fmt.Sprintf("wscale %d", opt.OptionData[0]))
|
||||
}
|
||||
case layers.TCPOptionKindSACKPermitted:
|
||||
parts = append(parts, "sackOK")
|
||||
case layers.TCPOptionKindSACK:
|
||||
blocks := len(opt.OptionData) / 8
|
||||
parts = append(parts, fmt.Sprintf("sack %d", blocks))
|
||||
case layers.TCPOptionKindTimestamps:
|
||||
if len(opt.OptionData) == 8 {
|
||||
tsval := binary.BigEndian.Uint32(opt.OptionData[0:4])
|
||||
tsecr := binary.BigEndian.Uint32(opt.OptionData[4:8])
|
||||
parts = append(parts, fmt.Sprintf("TS val %d ecr %d", tsval, tsecr))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// --- Protocol annotation ---
|
||||
|
||||
// annotatePayload returns a protocol annotation string for known application protocols.
|
||||
func annotatePayload(payload []byte) string {
|
||||
if len(payload) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := string(payload)
|
||||
|
||||
// SSH banner: "SSH-2.0-OpenSSH_9.6\r\n"
|
||||
if strings.HasPrefix(s, "SSH-") {
|
||||
if end := strings.IndexByte(s, '\r'); end > 0 && end < 256 {
|
||||
return ": " + s[:end]
|
||||
}
|
||||
}
|
||||
|
||||
// TLS records
|
||||
if ann := annotateTLS(payload); ann != "" {
|
||||
return ": " + ann
|
||||
}
|
||||
|
||||
// HTTP request or response
|
||||
for _, method := range [...]string{"GET ", "POST ", "PUT ", "DELETE ", "HEAD ", "PATCH ", "OPTIONS ", "CONNECT "} {
|
||||
if strings.HasPrefix(s, method) {
|
||||
if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 {
|
||||
return ": " + s[:end]
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(s, "HTTP/") {
|
||||
if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 {
|
||||
return ": " + s[:end]
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// annotateTLS returns a description for TLS handshake and alert records.
|
||||
func annotateTLS(data []byte) string {
|
||||
if len(data) < 6 {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch data[0] {
|
||||
case 0x16:
|
||||
return annotateTLSHandshake(data)
|
||||
case 0x15:
|
||||
return annotateTLSAlert(data)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func annotateTLSHandshake(data []byte) string {
|
||||
if len(data) < 10 {
|
||||
return ""
|
||||
}
|
||||
switch data[5] {
|
||||
case 0x01:
|
||||
if sni := extractSNI(data); sni != "" {
|
||||
return "TLS ClientHello SNI=" + sni
|
||||
}
|
||||
return "TLS ClientHello"
|
||||
case 0x02:
|
||||
return "TLS ServerHello"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func annotateTLSAlert(data []byte) string {
|
||||
if len(data) < 7 {
|
||||
return ""
|
||||
}
|
||||
severity := "warning"
|
||||
if data[5] == 2 {
|
||||
severity = "fatal"
|
||||
}
|
||||
return fmt.Sprintf("TLS Alert %s %s", severity, tlsAlertDesc(data[6]))
|
||||
}
|
||||
|
||||
func tlsAlertDesc(code byte) string {
|
||||
switch code {
|
||||
case 0:
|
||||
return "close_notify"
|
||||
case 10:
|
||||
return "unexpected_message"
|
||||
case 40:
|
||||
return "handshake_failure"
|
||||
case 42:
|
||||
return "bad_certificate"
|
||||
case 43:
|
||||
return "unsupported_certificate"
|
||||
case 44:
|
||||
return "certificate_revoked"
|
||||
case 45:
|
||||
return "certificate_expired"
|
||||
case 48:
|
||||
return "unknown_ca"
|
||||
case 49:
|
||||
return "access_denied"
|
||||
case 50:
|
||||
return "decode_error"
|
||||
case 70:
|
||||
return "protocol_version"
|
||||
case 80:
|
||||
return "internal_error"
|
||||
case 86:
|
||||
return "inappropriate_fallback"
|
||||
case 90:
|
||||
return "user_canceled"
|
||||
case 112:
|
||||
return "unrecognized_name"
|
||||
default:
|
||||
return fmt.Sprintf("alert(%d)", code)
|
||||
}
|
||||
}
|
||||
|
||||
// extractSNI parses a TLS ClientHello and returns the SNI server name.
|
||||
func extractSNI(data []byte) string {
|
||||
if len(data) < 6 || data[0] != 0x16 {
|
||||
return ""
|
||||
}
|
||||
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
handshake := data[5:]
|
||||
if len(handshake) > recordLen {
|
||||
handshake = handshake[:recordLen]
|
||||
}
|
||||
|
||||
if len(handshake) < 4 || handshake[0] != 0x01 {
|
||||
return ""
|
||||
}
|
||||
hsLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
|
||||
body := handshake[4:]
|
||||
if len(body) > hsLen {
|
||||
body = body[:hsLen]
|
||||
}
|
||||
|
||||
extPos := clientHelloExtensionsOffset(body)
|
||||
if extPos < 0 {
|
||||
return ""
|
||||
}
|
||||
return findSNIExtension(body, extPos)
|
||||
}
|
||||
|
||||
// clientHelloExtensionsOffset returns the byte offset where extensions begin
|
||||
// within the ClientHello body, or -1 if the body is too short.
|
||||
func clientHelloExtensionsOffset(body []byte) int {
|
||||
if len(body) < 38 {
|
||||
return -1
|
||||
}
|
||||
pos := 34
|
||||
|
||||
if pos >= len(body) {
|
||||
return -1
|
||||
}
|
||||
pos += 1 + int(body[pos]) // session ID
|
||||
|
||||
if pos+2 > len(body) {
|
||||
return -1
|
||||
}
|
||||
pos += 2 + int(binary.BigEndian.Uint16(body[pos:pos+2])) // cipher suites
|
||||
|
||||
if pos >= len(body) {
|
||||
return -1
|
||||
}
|
||||
pos += 1 + int(body[pos]) // compression methods
|
||||
|
||||
if pos+2 > len(body) {
|
||||
return -1
|
||||
}
|
||||
return pos
|
||||
}
|
||||
|
||||
func findSNIExtension(body []byte, pos int) string {
|
||||
extLen := int(binary.BigEndian.Uint16(body[pos : pos+2]))
|
||||
pos += 2
|
||||
extEnd := pos + extLen
|
||||
if extEnd > len(body) {
|
||||
extEnd = len(body)
|
||||
}
|
||||
|
||||
for pos+4 <= extEnd {
|
||||
extType := binary.BigEndian.Uint16(body[pos : pos+2])
|
||||
eLen := int(binary.BigEndian.Uint16(body[pos+2 : pos+4]))
|
||||
pos += 4
|
||||
if pos+eLen > extEnd {
|
||||
break
|
||||
}
|
||||
if extType == 0 && eLen >= 5 {
|
||||
nameLen := int(binary.BigEndian.Uint16(body[pos+3 : pos+5]))
|
||||
if pos+5+nameLen <= extEnd {
|
||||
return string(body[pos+5 : pos+5+nameLen])
|
||||
}
|
||||
}
|
||||
pos += eLen
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isDNSPort(src, dst uint16) bool {
|
||||
return src == 53 || dst == 53 || src == 5353 || dst == 5353
|
||||
}
|
||||
|
||||
// formatDNSPayload parses DNS and returns a tcpdump-style summary.
|
||||
func formatDNSPayload(payload []byte) string {
|
||||
d := &layers.DNS{}
|
||||
if err := d.DecodeFromBytes(payload, gopacket.NilDecodeFeedback); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
rd := ""
|
||||
if d.RD {
|
||||
rd = "+"
|
||||
}
|
||||
|
||||
if !d.QR {
|
||||
return formatDNSQuery(d, rd, len(payload))
|
||||
}
|
||||
return formatDNSResponse(d, rd, len(payload))
|
||||
}
|
||||
|
||||
func formatDNSQuery(d *layers.DNS, rd string, plen int) string {
|
||||
if len(d.Questions) == 0 {
|
||||
return fmt.Sprintf("%04x%s (%d)", d.ID, rd, plen)
|
||||
}
|
||||
q := d.Questions[0]
|
||||
return fmt.Sprintf("%04x%s %s? %s. (%d)", d.ID, rd, q.Type, q.Name, plen)
|
||||
}
|
||||
|
||||
func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
|
||||
anCount := d.ANCount
|
||||
nsCount := d.NSCount
|
||||
arCount := d.ARCount
|
||||
|
||||
if d.ResponseCode != layers.DNSResponseCodeNoErr {
|
||||
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen)
|
||||
}
|
||||
|
||||
if anCount > 0 && len(d.Answers) > 0 {
|
||||
rr := d.Answers[0]
|
||||
if rdata := shortRData(&rr); rdata != "" {
|
||||
return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
|
||||
}
|
||||
|
||||
func shortRData(rr *layers.DNSResourceRecord) string {
|
||||
switch rr.Type {
|
||||
case layers.DNSTypeA, layers.DNSTypeAAAA:
|
||||
if rr.IP != nil {
|
||||
return rr.IP.String()
|
||||
}
|
||||
case layers.DNSTypeCNAME:
|
||||
if len(rr.CNAME) > 0 {
|
||||
return string(rr.CNAME) + "."
|
||||
}
|
||||
case layers.DNSTypePTR:
|
||||
if len(rr.PTR) > 0 {
|
||||
return string(rr.PTR) + "."
|
||||
}
|
||||
case layers.DNSTypeNS:
|
||||
if len(rr.NS) > 0 {
|
||||
return string(rr.NS) + "."
|
||||
}
|
||||
case layers.DNSTypeMX:
|
||||
return fmt.Sprintf("%d %s.", rr.MX.Preference, rr.MX.Name)
|
||||
case layers.DNSTypeTXT:
|
||||
if len(rr.TXTs) > 0 {
|
||||
return fmt.Sprintf("%q", string(rr.TXTs[0]))
|
||||
}
|
||||
case layers.DNSTypeSRV:
|
||||
return fmt.Sprintf("%d %d %d %s.", rr.SRV.Priority, rr.SRV.Weight, rr.SRV.Port, rr.SRV.Name)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user