diff --git a/client/Dockerfile b/client/Dockerfile index 64d5ba04f..53e4555ef 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -17,6 +17,7 @@ ENV \ NETBIRD_BIN="/usr/local/bin/netbird" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 69d00aaf2..706bf40de 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -23,6 +23,7 @@ ENV \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_DISABLE_DNS="true" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/cmd/capture.go b/client/cmd/capture.go new file mode 100644 index 000000000..70522fdbb --- /dev/null +++ b/client/cmd/capture.go @@ -0,0 +1,186 @@ +package cmd + +import ( + "context" + "fmt" + "io" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +var captureCmd = &cobra.Command{ + Use: "capture", + Short: "Capture packets on the WireGuard interface", + Long: `Captures decrypted packets flowing through the WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Requires --enable-capture to be set at service install or reconfigure time. + +Examples: + netbird debug capture + netbird debug capture host 100.64.0.1 and port 443 + netbird debug capture tcp + netbird debug capture icmp + netbird debug capture src host 10.0.0.1 and dst port 80 + netbird debug capture -o capture.pcap + netbird debug capture --pcap | tshark -r - + netbird debug capture --pcap | tcpdump -r - -n`, + Args: cobra.ArbitraryArgs, + RunE: runCapture, +} + +func init() { + debugCmd.AddCommand(captureCmd) + + captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length") + captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)") + captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)") + captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)") + captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") +} + +func runCapture(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + cmd.PrintErrf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + + req, err := buildCaptureRequest(cmd, args) + if err != nil { + return err + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + stream, err := client.StartCapture(ctx, req) + if err != nil { + return handleCaptureError(err) + } + + // First Recv is the empty acceptance message from the server. If the + // device is unavailable (kernel WG, not connected, capture disabled), + // the server returns an error instead. + if _, err := stream.Recv(); err != nil { + return handleCaptureError(err) + } + + out, cleanup, err := captureOutput(cmd) + if err != nil { + return err + } + defer cleanup() + + if req.TextOutput { + cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n") + } else { + cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n") + } + + return streamCapture(ctx, cmd, stream, out) +} + +func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) { + req := &proto.StartCaptureRequest{} + + if len(args) > 0 { + expr := strings.Join(args, " ") + if _, err := capture.ParseFilter(expr); err != nil { + return nil, fmt.Errorf("invalid filter: %w", err) + } + req.FilterExpr = expr + } + + if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 { + req.SnapLen = snap + } + if d, _ := cmd.Flags().GetDuration("duration"); d != 0 { + if d < 0 { + return nil, fmt.Errorf("duration must not be negative") + } + req.Duration = durationpb.New(d) + } + req.Verbose, _ = cmd.Flags().GetBool("verbose") + req.Ascii, _ = cmd.Flags().GetBool("ascii") + + outPath, _ := cmd.Flags().GetString("output") + forcePcap, _ := cmd.Flags().GetBool("pcap") + req.TextOutput = !forcePcap && outPath == "" + + return req, nil +} + +func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error { + for { + pkt, err := stream.Recv() + if err != nil { + if ctx.Err() != nil { + cmd.PrintErrf("\nCapture stopped.\n") + return nil //nolint:nilerr // user interrupted + } + if err == io.EOF { + cmd.PrintErrf("\nCapture finished.\n") + return nil + } + return handleCaptureError(err) + } + if _, err := out.Write(pkt.GetData()); err != nil { + return fmt.Errorf("write output: %w", err) + } + } +} + +// captureOutput returns the writer for capture data and a cleanup function. +func captureOutput(cmd *cobra.Command) (io.Writer, func(), error) { + outPath, _ := cmd.Flags().GetString("output") + if outPath == "" { + return os.Stdout, func() { + // no cleanup needed for stdout + }, nil + } + + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() { + if err := f.Close(); err != nil { + cmd.PrintErrf("close output file: %v\n", err) + } + if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 { + if err := os.Rename(tmpPath, outPath); err != nil { + cmd.PrintErrf("rename output file: %v\n", err) + } else { + cmd.PrintErrf("Wrote %s\n", outPath) + } + } else { + os.Remove(tmpPath) + } + }, nil +} + +func handleCaptureError(err error) error { + if s, ok := status.FromError(err); ok { + return fmt.Errorf("%s", s.Message()) + } + return err +} diff --git a/client/cmd/debug.go b/client/cmd/debug.go index e3d3afe5f..2a8cdc887 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/debug" @@ -239,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error { }() } + captureStarted := false + if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture { + captureTimeout := duration + 30*time.Second + const maxBundleCapture = 10 * time.Minute + if captureTimeout > maxBundleCapture { + captureTimeout = maxBundleCapture + } + _, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(captureTimeout), + }) + if err != nil { + cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message()) + } else { + captureStarted = true + cmd.Println("Packet capture started.") + // Safety: always stop on exit, even if the normal stop below runs too. + defer func() { + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } + } + }() + } + } + if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr } cmd.Println("\nDuration completed") + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } else { + captureStarted = false + cmd.Println("Packet capture stopped.") + } + } + if cpuProfilingStarted { if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) @@ -416,4 +456,5 @@ func init() { forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") + forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle") } diff --git a/client/cmd/root.go b/client/cmd/root.go index aa5b98dfd..dc21abaca 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -75,6 +75,7 @@ var ( mtu uint16 profilesDisabled bool updateSettingsDisabled bool + captureEnabled bool rootCmd = &cobra.Command{ Use: "netbird", diff --git a/client/cmd/service.go b/client/cmd/service.go index 5ff16eaeb..5addb644a 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -44,6 +44,7 @@ func init() { serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") + serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") serviceEnvDesc := `Sets extra environment variables for the service. ` + diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 5fe318ddf..0a2b14f35 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 28770ea16..e3011b5e9 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -59,6 +59,10 @@ func buildServiceArguments() []string { args = append(args, "--disable-update-settings") } + if captureEnabled { + args = append(args, "--enable-capture") + } + return args } diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go index 81bd2dbb5..8bb22045e 100644 --- a/client/cmd/service_params.go +++ b/client/cmd/service_params.go @@ -28,6 +28,7 @@ type serviceParams struct { LogFiles []string `json:"log_files,omitempty"` DisableProfiles bool `json:"disable_profiles,omitempty"` DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + EnableCapture bool `json:"enable_capture,omitempty"` ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` } @@ -78,6 +79,7 @@ func currentServiceParams() *serviceParams { LogFiles: logFiles, DisableProfiles: profilesDisabled, DisableUpdateSettings: updateSettingsDisabled, + EnableCapture: captureEnabled, } if len(serviceEnvVars) > 0 { @@ -142,6 +144,10 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) { updateSettingsDisabled = params.DisableUpdateSettings } + if !serviceCmd.PersistentFlags().Changed("enable-capture") { + captureEnabled = params.EnableCapture + } + applyServiceEnvParams(cmd, params) } diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go index 3bc8e4f60..0c2f699cd 100644 --- a/client/cmd/service_params_test.go +++ b/client/cmd/service_params_test.go @@ -500,6 +500,7 @@ func fieldToGlobalVar(field string) string { "LogFiles": "logFiles", "DisableProfiles": "profilesDisabled", "DisableUpdateSettings": "updateSettingsDisabled", + "EnableCapture": "captureEnabled", "ServiceEnvVars": "serviceEnvVars", } if v, ok := m[field]; ok { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 4bda33e65..5c6926f04 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -152,7 +152,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false) + "", "", false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 88f7e541c..f9a22718e 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/capture" ) var ( @@ -489,6 +490,22 @@ func (c *Client) getEngine() (*internal.Engine, error) { return engine, nil } +// SetCapture sets or clears packet capture on this client's WireGuard device. +// Pass nil to disable capture. +func (c *Client) SetCapture(sess *capture.Session) error { + engine, err := c.getEngine() + if err != nil { + return err + } + // Explicit nil check to avoid wrapping a nil *Session in the + // device.PacketCapture interface, which would appear non-nil to + // FilteredDevice and cause a nil-pointer dereference in Offer. + if sess == nil { + return engine.SetCapture(nil) + } + return engine.SetCapture(sess) +} + func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { engine, err := c.getEngine() if err != nil { diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 24b3d0167..ebb11b2d4 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -115,12 +115,13 @@ type Manager struct { localipmanager *localIPManager - udpTracker *conntrack.UDPTracker - icmpTracker *conntrack.ICMPTracker - tcpTracker *conntrack.TCPTracker - forwarder atomic.Pointer[forwarder.Forwarder] - logger *nblog.Logger - flowLogger nftypes.FlowLogger + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker + forwarder atomic.Pointer[forwarder.Forwarder] + pendingCapture atomic.Pointer[forwarder.PacketCapture] + logger *nblog.Logger + flowLogger nftypes.FlowLogger blockRule firewall.Rule @@ -351,6 +352,19 @@ func (m *Manager) determineRouting() error { return nil } +// SetPacketCapture sets or clears packet capture on the forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. +func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) { + if pc == nil { + m.pendingCapture.Store(nil) + } else { + m.pendingCapture.Store(&pc) + } + if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(pc) + } +} + // initForwarder initializes the forwarder, it disables routing on errors func (m *Manager) initForwarder() error { if m.forwarder.Load() != nil { @@ -370,6 +384,10 @@ func (m *Manager) initForwarder() error { return fmt.Errorf("create forwarder: %w", err) } + if pc := m.pendingCapture.Load(); pc != nil { + forwarder.SetCapture(*pc) + } + m.forwarder.Store(forwarder) log.Debug("forwarder initialized") @@ -614,6 +632,7 @@ func (m *Manager) resetState() { } if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(nil) fwder.Stop() } diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 692a24140..96ab89af8 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -12,12 +12,19 @@ import ( nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + Offer(data []byte, outbound bool) +} + // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device type endpoint struct { logger *nblog.Logger dispatcher stack.NetworkDispatcher device *wgdevice.Device mtu atomic.Uint32 + capture atomic.Pointer[PacketCapture] } func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -54,13 +61,17 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) continue } - // Send the packet through WireGuard + pktBytes := data.AsSlice() + address := netHeader.DestinationAddress() - err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) - if err != nil { + if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil { e.logger.Error1("CreateOutboundPacket: %v", err) continue } + + if pc := e.capture.Load(); pc != nil { + (*pc).Offer(pktBytes, true) + } written++ } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index d17c3cd5c..925273f24 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -139,6 +139,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return f, nil } +// SetCapture sets or clears the packet capture on the forwarder endpoint. +// This captures outbound packets that bypass the FilteredDevice (netstack forwarding). +func (f *Forwarder) SetCapture(pc PacketCapture) { + if pc == nil { + f.endpoint.capture.Store(nil) + return + } + f.endpoint.capture.Store(&pc) +} + func (f *Forwarder) InjectIncomingPacket(payload []byte) error { if len(payload) < header.IPv4MinimumSize { return fmt.Errorf("packet too small: %d bytes", len(payload)) diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index cb3db325d..217423901 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -270,5 +270,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload [] return 0 } + if pc := f.endpoint.capture.Load(); pc != nil { + (*pc).Offer(fullPacket, true) + } + return len(fullPacket) } diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 4357d1916..fc1c65efa 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -3,6 +3,7 @@ package device import ( "net/netip" "sync" + "sync/atomic" "golang.zx2c4.com/wireguard/tun" ) @@ -28,11 +29,20 @@ type PacketFilter interface { SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) } +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + // Offer submits a packet for capture. outbound is true for packets + // leaving the host (Read path), false for packets arriving (Write path). + Offer(data []byte, outbound bool) +} + // FilteredDevice to override Read or Write of packets type FilteredDevice struct { tun.Device filter PacketFilter + capture atomic.Pointer[PacketCapture] mutex sync.RWMutex closeOnce sync.Once } @@ -63,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() - if filter == nil { - return + if filter != nil { + for i := 0; i < n; i++ { + if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { + bufs = append(bufs[:i], bufs[i+1:]...) + sizes = append(sizes[:i], sizes[i+1:]...) + n-- + i-- + } + } } - for i := 0; i < n; i++ { - if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { - bufs = append(bufs[:i], bufs[i+1:]...) - sizes = append(sizes[:i], sizes[i+1:]...) - n-- - i-- + if pc := d.capture.Load(); pc != nil { + for i := 0; i < n; i++ { + (*pc).Offer(bufs[i][offset:offset+sizes[i]], true) } } @@ -85,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er // Write wraps write method with filtering feature func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { + // Capture before filtering so dropped packets are still visible in captures. + if pc := d.capture.Load(); pc != nil { + for _, buf := range bufs { + (*pc).Offer(buf[offset:], false) + } + } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -96,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { filteredBufs := make([][]byte, 0, len(bufs)) dropped := 0 for _, buf := range bufs { - if !filter.FilterInbound(buf[offset:], len(buf)) { - filteredBufs = append(filteredBufs, buf) + if filter.FilterInbound(buf[offset:], len(buf)) { dropped++ + } else { + filteredBufs = append(filteredBufs, buf) } } @@ -113,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.filter = filter d.mutex.Unlock() } + +// SetCapture sets or clears the packet capture sink. Pass nil to disable. +// Uses atomic store so the hot path (Read/Write) is a single pointer load +// with no locking overhead when capture is off. +func (d *FilteredDevice) SetCapture(pc PacketCapture) { + if pc == nil { + d.capture.Store(nil) + return + } + d.capture.Store(&pc) +} diff --git a/client/iface/device/device_filter_test.go b/client/iface/device/device_filter_test.go index eef783542..8fb16ca8d 100644 --- a/client/iface/device/device_filter_test.go +++ b/client/iface/device/device_filter_test.go @@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) { t.Errorf("unexpected error: %v", err) return } - if n != 0 { + if n != 1 { t.Errorf("expected n=1, got %d", n) return } diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 6a8eae324..bad481519 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -63,6 +63,7 @@ allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. cpu.prof: CPU profiling information. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. +capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data. Anonymization Process @@ -235,6 +236,7 @@ type BundleGenerator struct { syncResponse *mgmProto.SyncResponse logPath string cpuProfile []byte + capturePath string refreshStatus func() // Optional callback to refresh status before bundle generation clientMetrics MetricsExporter @@ -257,7 +259,8 @@ type GeneratorDependencies struct { SyncResponse *mgmProto.SyncResponse LogPath string CPUProfile []byte - RefreshStatus func() // Optional callback to refresh status before bundle generation + CapturePath string + RefreshStatus func() ClientMetrics MetricsExporter } @@ -276,6 +279,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen syncResponse: deps.SyncResponse, logPath: deps.LogPath, cpuProfile: deps.CPUProfile, + capturePath: deps.CapturePath, refreshStatus: deps.RefreshStatus, clientMetrics: deps.ClientMetrics, @@ -345,6 +349,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add CPU profile to debug bundle: %v", err) } + if err := g.addCaptureFile(); err != nil { + log.Errorf("failed to add capture file to debug bundle: %v", err) + } + if err := g.addStackTrace(); err != nil { log.Errorf("failed to add stack trace to debug bundle: %v", err) } @@ -675,6 +683,29 @@ func (g *BundleGenerator) addCPUProfile() error { return nil } +func (g *BundleGenerator) addCaptureFile() error { + if g.capturePath == "" { + return nil + } + + if g.anonymize { + log.Info("skipping capture file in anonymized bundle (contains raw packet data)") + return nil + } + + f, err := os.Open(g.capturePath) + if err != nil { + return fmt.Errorf("open capture file: %w", err) + } + defer f.Close() + + if err := g.addFileToZip(f, "capture.pcap"); err != nil { + return fmt.Errorf("add capture file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addStackTrace() error { buf := make([]byte, 5242880) // 5 MB buffer n := runtime.Stack(buf, true) diff --git a/client/internal/engine.go b/client/internal/engine.go index be2d8bbf3..ef643872f 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -27,6 +27,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" @@ -67,6 +68,7 @@ import ( signal "github.com/netbirdio/netbird/shared/signal/client" sProto "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/capture" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -216,6 +218,8 @@ type Engine struct { portForwardManager *portforward.Manager srWatcher *guard.SRWatcher + afpacketCapture *capture.AFPacketCapture + // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex persistSyncResponse bool @@ -1693,6 +1697,11 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { @@ -2158,6 +2167,62 @@ func (e *Engine) Address() (netip.Addr, error) { return e.wgInterface.Address().IP, nil } +// SetCapture sets or clears packet capture on the WireGuard device. +// On userspace WireGuard, it taps the FilteredDevice directly. +// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture. +// Pass nil to disable capture. +func (e *Engine) SetCapture(pc device.PacketCapture) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + intf := e.wgInterface + if intf == nil { + return errors.New("wireguard interface not initialized") + } + + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + + dev := intf.GetDevice() + if dev != nil { + dev.SetCapture(pc) + e.setForwarderCapture(pc) + return nil + } + + // Kernel mode: no FilteredDevice. Use AF_PACKET on Linux. + if pc == nil { + return nil + } + sess, ok := pc.(*capture.Session) + if !ok { + return errors.New("filtered device not available and AF_PACKET requires *capture.Session") + } + + afc := capture.NewAFPacketCapture(intf.Name(), sess) + if err := afc.Start(); err != nil { + return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err) + } + e.afpacketCapture = afc + return nil +} + +// setForwarderCapture propagates capture to the USP filter's forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. +func (e *Engine) setForwarderCapture(pc device.PacketCapture) { + if e.firewall == nil { + return + } + type forwarderCapturer interface { + SetPacketCapture(pc forwarder.PacketCapture) + } + if fc, ok := e.firewall.(forwarderCapturer); ok { + fc.SetPacketCapture(pc) + } +} + func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { if e.firewall == nil { log.Warn("firewall is disabled, not updating forwarding rules") diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index fa0b2f93b..47283e216 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -5969,6 +5969,288 @@ func (x *ExposeServiceReady) GetPortAutoAssigned() bool { return false } +type StartCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TextOutput bool `protobuf:"varint,1,opt,name=text_output,json=textOutput,proto3" json:"text_output,omitempty"` + SnapLen uint32 `protobuf:"varint,2,opt,name=snap_len,json=snapLen,proto3" json:"snap_len,omitempty"` + Duration *durationpb.Duration `protobuf:"bytes,3,opt,name=duration,proto3" json:"duration,omitempty"` + FilterExpr string `protobuf:"bytes,4,opt,name=filter_expr,json=filterExpr,proto3" json:"filter_expr,omitempty"` + Verbose bool `protobuf:"varint,5,opt,name=verbose,proto3" json:"verbose,omitempty"` + Ascii bool `protobuf:"varint,6,opt,name=ascii,proto3" json:"ascii,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCaptureRequest) Reset() { + *x = StartCaptureRequest{} + mi := &file_daemon_proto_msgTypes[90] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCaptureRequest) ProtoMessage() {} + +func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[90] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead. +func (*StartCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{90} +} + +func (x *StartCaptureRequest) GetTextOutput() bool { + if x != nil { + return x.TextOutput + } + return false +} + +func (x *StartCaptureRequest) GetSnapLen() uint32 { + if x != nil { + return x.SnapLen + } + return 0 +} + +func (x *StartCaptureRequest) GetDuration() *durationpb.Duration { + if x != nil { + return x.Duration + } + return nil +} + +func (x *StartCaptureRequest) GetFilterExpr() string { + if x != nil { + return x.FilterExpr + } + return "" +} + +func (x *StartCaptureRequest) GetVerbose() bool { + if x != nil { + return x.Verbose + } + return false +} + +func (x *StartCaptureRequest) GetAscii() bool { + if x != nil { + return x.Ascii + } + return false +} + +type CapturePacket struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CapturePacket) Reset() { + *x = CapturePacket{} + mi := &file_daemon_proto_msgTypes[91] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CapturePacket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CapturePacket) ProtoMessage() {} + +func (x *CapturePacket) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[91] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead. +func (*CapturePacket) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{91} +} + +func (x *CapturePacket) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type StartBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + Timeout *durationpb.Duration `protobuf:"bytes,1,opt,name=timeout,proto3" json:"timeout,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureRequest) Reset() { + *x = StartBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[92] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureRequest) ProtoMessage() {} + +func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[92] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead. +func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{92} +} + +func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration { + if x != nil { + return x.Timeout + } + return nil +} + +type StartBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureResponse) Reset() { + *x = StartBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[93] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureResponse) ProtoMessage() {} + +func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[93] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead. +func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{93} +} + +type StopBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureRequest) Reset() { + *x = StopBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[94] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureRequest) ProtoMessage() {} + +func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[94] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead. +func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{94} +} + +type StopBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureResponse) Reset() { + *x = StopBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[95] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureResponse) ProtoMessage() {} + +func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[95] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead. +func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{95} +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5979,7 +6261,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[91] + mi := &file_daemon_proto_msgTypes[97] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5991,7 +6273,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[91] + mi := &file_daemon_proto_msgTypes[97] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6539,7 +6821,23 @@ const file_daemon_proto_rawDesc = "" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + - "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned\"\xd9\x01\n" + + "\x13StartCaptureRequest\x12\x1f\n" + + "\vtext_output\x18\x01 \x01(\bR\n" + + "textOutput\x12\x19\n" + + "\bsnap_len\x18\x02 \x01(\rR\asnapLen\x125\n" + + "\bduration\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\bduration\x12\x1f\n" + + "\vfilter_expr\x18\x04 \x01(\tR\n" + + "filterExpr\x12\x18\n" + + "\averbose\x18\x05 \x01(\bR\averbose\x12\x14\n" + + "\x05ascii\x18\x06 \x01(\bR\x05ascii\"#\n" + + "\rCapturePacket\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"P\n" + + "\x19StartBundleCaptureRequest\x123\n" + + "\atimeout\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\atimeout\"\x1c\n" + + "\x1aStartBundleCaptureResponse\"\x1a\n" + + "\x18StopBundleCaptureRequest\"\x1b\n" + + "\x19StopBundleCaptureResponse*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6557,7 +6855,7 @@ const file_daemon_proto_rawDesc = "" + "\n" + "EXPOSE_UDP\x10\x03\x12\x0e\n" + "\n" + - "EXPOSE_TLS\x10\x042\xfc\x15\n" + + "EXPOSE_TLS\x10\x042\xff\x17\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6578,7 +6876,10 @@ const file_daemon_proto_rawDesc = "" + "CleanState\x12\x19.daemon.CleanStateRequest\x1a\x1a.daemon.CleanStateResponse\"\x00\x12H\n" + "\vDeleteState\x12\x1a.daemon.DeleteStateRequest\x1a\x1b.daemon.DeleteStateResponse\"\x00\x12u\n" + "\x1aSetSyncResponsePersistence\x12).daemon.SetSyncResponsePersistenceRequest\x1a*.daemon.SetSyncResponsePersistenceResponse\"\x00\x12H\n" + - "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" + + "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12F\n" + + "\fStartCapture\x12\x1b.daemon.StartCaptureRequest\x1a\x15.daemon.CapturePacket\"\x000\x01\x12]\n" + + "\x12StartBundleCapture\x12!.daemon.StartBundleCaptureRequest\x1a\".daemon.StartBundleCaptureResponse\"\x00\x12Z\n" + + "\x11StopBundleCapture\x12 .daemon.StopBundleCaptureRequest\x1a!.daemon.StopBundleCaptureResponse\"\x00\x12D\n" + "\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" + "\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" + "\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" + @@ -6613,7 +6914,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 93) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 99) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (ExposeProtocol)(0), // 1: daemon.ExposeProtocol @@ -6710,128 +7011,142 @@ var file_daemon_proto_goTypes = []any{ (*ExposeServiceRequest)(nil), // 92: daemon.ExposeServiceRequest (*ExposeServiceEvent)(nil), // 93: daemon.ExposeServiceEvent (*ExposeServiceReady)(nil), // 94: daemon.ExposeServiceReady - nil, // 95: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 96: daemon.PortInfo.Range - nil, // 97: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 98: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 99: google.protobuf.Timestamp + (*StartCaptureRequest)(nil), // 95: daemon.StartCaptureRequest + (*CapturePacket)(nil), // 96: daemon.CapturePacket + (*StartBundleCaptureRequest)(nil), // 97: daemon.StartBundleCaptureRequest + (*StartBundleCaptureResponse)(nil), // 98: daemon.StartBundleCaptureResponse + (*StopBundleCaptureRequest)(nil), // 99: daemon.StopBundleCaptureRequest + (*StopBundleCaptureResponse)(nil), // 100: daemon.StopBundleCaptureResponse + nil, // 101: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 102: daemon.PortInfo.Range + nil, // 103: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 104: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 105: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType - 98, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 99, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 99, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 98, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState - 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState - 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 95, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 96, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State - 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 99, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 97, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 98, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol - 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady - 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 78, // 64: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest - 80, // 65: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 82, // 66: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 84, // 67: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 86, // 68: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 88, // 69: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 6, // 70: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest - 90, // 71: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 92, // 72: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest - 9, // 73: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 11, // 74: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 13, // 75: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 15, // 76: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 17, // 77: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 19, // 78: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 30, // 79: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 32, // 80: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 32, // 81: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 37, // 82: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 39, // 83: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 41, // 84: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 43, // 85: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 46, // 86: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 48, // 87: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 50, // 88: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 52, // 89: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 56, // 90: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 58, // 91: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 60, // 92: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 62, // 93: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 64, // 94: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 66, // 95: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 68, // 96: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 70, // 97: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 73, // 98: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 75, // 99: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 77, // 100: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 79, // 101: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse - 81, // 102: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 83, // 103: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 85, // 104: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 87, // 105: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 89, // 106: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 7, // 107: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse - 91, // 108: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 93, // 109: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent - 73, // [73:110] is the sub-list for method output_type - 36, // [36:73] is the sub-list for method input_type - 36, // [36:36] is the sub-list for extension type_name - 36, // [36:36] is the sub-list for extension extendee - 0, // [0:36] is the sub-list for field type_name + 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType + 104, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 105, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 105, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 104, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState + 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState + 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 101, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 102, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State + 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 105, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 103, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 104, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol + 94, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 104, // 35: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration + 104, // 36: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration + 33, // 37: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 8, // 38: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 10, // 39: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 12, // 40: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 14, // 41: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 16, // 42: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 18, // 43: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 29, // 44: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 31, // 45: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 31, // 46: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 5, // 47: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 38, // 48: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 40, // 49: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 42, // 50: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 45, // 51: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 47, // 52: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 49, // 53: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 51, // 54: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 54, // 55: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 95, // 56: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest + 97, // 57: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest + 99, // 58: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest + 57, // 59: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 59, // 60: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 61, // 61: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 63, // 62: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 65, // 63: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 67, // 64: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 69, // 65: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 72, // 66: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 74, // 67: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 76, // 68: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 78, // 69: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 80, // 70: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 82, // 71: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 84, // 72: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 86, // 73: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 88, // 74: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 6, // 75: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 90, // 76: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 92, // 77: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 9, // 78: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 11, // 79: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 13, // 80: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 15, // 81: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 17, // 82: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 19, // 83: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 30, // 84: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 32, // 85: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 32, // 86: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 37, // 87: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 39, // 88: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 41, // 89: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 43, // 90: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 46, // 91: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 48, // 92: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 50, // 93: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 52, // 94: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 56, // 95: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 96, // 96: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket + 98, // 97: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse + 100, // 98: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse + 58, // 99: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 60, // 100: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 62, // 101: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 64, // 102: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 66, // 103: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 68, // 104: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 70, // 105: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 73, // 106: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 75, // 107: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 77, // 108: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 79, // 109: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 81, // 110: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 83, // 111: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 85, // 112: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 87, // 113: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 89, // 114: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 7, // 115: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 91, // 116: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 93, // 117: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 78, // [78:118] is the sub-list for method output_type + 38, // [38:78] is the sub-list for method input_type + 38, // [38:38] is the sub-list for extension type_name + 38, // [38:38] is the sub-list for extension extendee + 0, // [0:38] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6861,7 +7176,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 5, - NumMessages: 93, + NumMessages: 99, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 89302c8c3..b8d1418f3 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -64,6 +64,17 @@ service DaemonService { rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {} + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {} + + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {} + + // StopBundleCapture stops the running bundle capture. Idempotent. + rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {} + rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {} rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {} @@ -847,3 +858,26 @@ message ExposeServiceReady { string domain = 3; bool port_auto_assigned = 4; } + +message StartCaptureRequest { + bool text_output = 1; + uint32 snap_len = 2; + google.protobuf.Duration duration = 3; + string filter_expr = 4; + bool verbose = 5; + bool ascii = 6; +} + +message CapturePacket { + bytes data = 1; +} + +message StartBundleCaptureRequest { + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + google.protobuf.Duration timeout = 1; +} + +message StartBundleCaptureResponse {} +message StopBundleCaptureRequest {} +message StopBundleCaptureResponse {} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e5bd89597..cfca1c09b 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -53,6 +53,14 @@ type DaemonServiceClient interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (DaemonService_StartCaptureClient, error) + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. + StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) @@ -253,8 +261,58 @@ func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRe return out, nil } +func (c *daemonServiceClient) StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (DaemonService_StartCaptureClient, error) { + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/StartCapture", opts...) + if err != nil { + return nil, err + } + x := &daemonServiceStartCaptureClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type DaemonService_StartCaptureClient interface { + Recv() (*CapturePacket, error) + grpc.ClientStream +} + +type daemonServiceStartCaptureClient struct { + grpc.ClientStream +} + +func (x *daemonServiceStartCaptureClient) Recv() (*CapturePacket, error) { + m := new(CapturePacket) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *daemonServiceClient) StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) { + out := new(StartBundleCaptureResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartBundleCapture", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) { + out := new(StopBundleCaptureResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopBundleCapture", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) { - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/SubscribeEvents", opts...) if err != nil { return nil, err } @@ -439,7 +497,7 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal } func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) { - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[2], "/daemon.DaemonService/ExposeService", opts...) if err != nil { return nil, err } @@ -509,6 +567,14 @@ type DaemonServiceServer interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(*StartCaptureRequest, DaemonService_StartCaptureServer) error + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. + StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) @@ -598,6 +664,15 @@ func (UnimplementedDaemonServiceServer) SetSyncResponsePersistence(context.Conte func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") } +func (UnimplementedDaemonServiceServer) StartCapture(*StartCaptureRequest, DaemonService_StartCaptureServer) error { + return status.Errorf(codes.Unimplemented, "method StartCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StartBundleCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method StopBundleCapture not implemented") +} func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error { return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented") } @@ -992,6 +1067,63 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_StartCapture_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StartCaptureRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DaemonServiceServer).StartCapture(m, &daemonServiceStartCaptureServer{stream}) +} + +type DaemonService_StartCaptureServer interface { + Send(*CapturePacket) error + grpc.ServerStream +} + +type daemonServiceStartCaptureServer struct { + grpc.ServerStream +} + +func (x *daemonServiceStartCaptureServer) Send(m *CapturePacket) error { + return x.ServerStream.SendMsg(m) +} + +func _DaemonService_StartBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/StartBundleCapture", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, req.(*StartBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_StopBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/StopBundleCapture", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, req.(*StopBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(SubscribeRequest) if err := stream.RecvMsg(m); err != nil { @@ -1419,6 +1551,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "TracePacket", Handler: _DaemonService_TracePacket_Handler, }, + { + MethodName: "StartBundleCapture", + Handler: _DaemonService_StartBundleCapture_Handler, + }, + { + MethodName: "StopBundleCapture", + Handler: _DaemonService_StopBundleCapture_Handler, + }, { MethodName: "GetEvents", Handler: _DaemonService_GetEvents_Handler, @@ -1489,6 +1629,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ }, }, Streams: []grpc.StreamDesc{ + { + StreamName: "StartCapture", + Handler: _DaemonService_StartCapture_Handler, + ServerStreams: true, + }, { StreamName: "SubscribeEvents", Handler: _DaemonService_SubscribeEvents_Handler, diff --git a/client/server/capture.go b/client/server/capture.go new file mode 100644 index 000000000..17dbf9fd6 --- /dev/null +++ b/client/server/capture.go @@ -0,0 +1,325 @@ +package server + +import ( + "context" + "io" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +const maxBundleCaptureDuration = 10 * time.Minute + +// bundleCapture holds the state of an in-progress capture destined for the +// debug bundle. The lifecycle is: +// +// StartBundleCapture → capture running, writing to temp file +// StopBundleCapture → capture stopped, temp file available +// DebugBundle → temp file included in zip, then cleaned up +type bundleCapture struct { + mu sync.Mutex + sess *capture.Session + file *os.File + engine *internal.Engine + cancel context.CancelFunc + stopped bool +} + +// stop halts the capture session and closes the pcap writer. Idempotent. +func (bc *bundleCapture) stop() { + bc.mu.Lock() + defer bc.mu.Unlock() + + if bc.stopped { + return + } + bc.stopped = true + + if bc.cancel != nil { + bc.cancel() + } + if bc.engine != nil { + if err := bc.engine.SetCapture(nil); err != nil { + log.Debugf("clear bundle capture: %v", err) + } + } + if bc.sess != nil { + bc.sess.Stop() + } +} + +// path returns the temp file path, or "" if no file exists. +func (bc *bundleCapture) path() string { + if bc.file == nil { + return "" + } + return bc.file.Name() +} + +// cleanup removes the temp file. +func (bc *bundleCapture) cleanup() { + if bc.file == nil { + return + } + name := bc.file.Name() + if err := bc.file.Close(); err != nil { + log.Debugf("close bundle capture file: %v", err) + } + if err := os.Remove(name); err != nil && !os.IsNotExist(err) { + log.Debugf("remove bundle capture file: %v", err) + } + bc.file = nil +} + +// StartCapture streams a pcap or text packet capture over gRPC. +// Gated by the --enable-capture service flag. +func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error { + if !s.captureEnabled { + return status.Error(codes.PermissionDenied, + "packet capture is disabled; reinstall or reconfigure the service with --enable-capture") + } + + engine, err := s.getCaptureEngine() + if err != nil { + return err + } + + matcher, err := parseCaptureFilter(req) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) + } + + pr, pw := io.Pipe() + + opts := capture.Options{ + Matcher: matcher, + SnapLen: req.GetSnapLen(), + Verbose: req.GetVerbose(), + ASCII: req.GetAscii(), + } + if req.GetTextOutput() { + opts.TextOutput = pw + } else { + opts.Output = pw + } + + sess, err := capture.NewSession(opts) + if err != nil { + pw.Close() + return status.Errorf(codes.Internal, "create capture session: %v", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "set capture: %v", err) + } + + // Send an empty initial message to signal that the capture was accepted. + // The client waits for this before printing the banner, so it must arrive + // before any packet data. + if err := stream.Send(&proto.CapturePacket{}); err != nil { + if clearErr := engine.SetCapture(nil); clearErr != nil { + log.Debugf("clear capture after send failure: %v", clearErr) + } + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "send initial message: %v", err) + } + + ctx := stream.Context() + if d := req.GetDuration(); d != nil { + dur := d.AsDuration() + if dur < 0 { + if clearErr := engine.SetCapture(nil); clearErr != nil { + log.Debugf("clear capture: %v", clearErr) + } + sess.Stop() + pw.Close() + return status.Errorf(codes.InvalidArgument, "duration must not be negative") + } + if dur > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, dur) + defer cancel() + } + } + + go func() { + <-ctx.Done() + if err := engine.SetCapture(nil); err != nil { + log.Debugf("clear capture: %v", err) + } + sess.Stop() + pw.Close() + }() + defer pr.Close() + + log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr()) + defer func() { + stats := sess.Stats() + log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) + }() + + return streamToGRPC(pr, stream) +} + +func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error { + buf := make([]byte, 32*1024) + for { + n, readErr := r.Read(buf) + if n > 0 { + if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil { + log.Debugf("capture stream send: %v", err) + return nil //nolint:nilerr // client disconnected + } + } + if readErr != nil { + return nil //nolint:nilerr // pipe closed, capture stopped normally + } + } +} + +// StartBundleCapture begins capturing packets to a server-side temp file for +// inclusion in the next debug bundle. Not gated by --enable-capture since the +// output stays on the server (same trust level as CPU profiling). +// +// A timeout auto-stops the capture as a safety net if StopBundleCapture is +// never called (e.g. CLI crash). +func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + s.cleanupBundleCapture() + + engine, err := s.getCaptureEngineLocked() + if err != nil { + // Not fatal: kernel mode or not connected. Log and return success + // so the debug bundle still generates without capture data. + log.Warnf("packet capture unavailable, skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + + timeout := req.GetTimeout().AsDuration() + if timeout <= 0 || timeout > maxBundleCaptureDuration { + timeout = maxBundleCaptureDuration + } + + f, err := os.CreateTemp("", "netbird.capture.*.pcap") + if err != nil { + return nil, status.Errorf(codes.Internal, "create temp file: %v", err) + } + + sess, err := capture.NewSession(capture.Options{Output: f}) + if err != nil { + f.Close() + os.Remove(f.Name()) + return nil, status.Errorf(codes.Internal, "create capture session: %v", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + f.Close() + os.Remove(f.Name()) + log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + bc := &bundleCapture{ + sess: sess, + file: f, + engine: engine, + cancel: cancel, + } + + go func() { + <-ctx.Done() + bc.stop() + log.Infof("bundle capture auto-stopped after timeout") + }() + + s.bundleCapture = bc + log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name()) + + return &proto.StartBundleCaptureResponse{}, nil +} + +// StopBundleCapture stops the running bundle capture. Idempotent. +func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + return &proto.StopBundleCaptureResponse{}, nil +} + +// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex. +func (s *Server) stopBundleCaptureLocked() { + if s.bundleCapture == nil { + return + } + s.bundleCapture.stop() + + stats := s.bundleCapture.sess.Stats() + log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) +} + +// bundleCapturePath returns the temp file path if a capture has been taken, +// stops any running capture, and returns "". Called from DebugBundle. +// Must hold s.mutex. +func (s *Server) bundleCapturePath() string { + if s.bundleCapture == nil { + return "" + } + + s.bundleCapture.stop() + return s.bundleCapture.path() +} + +// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex. +func (s *Server) cleanupBundleCapture() { + if s.bundleCapture == nil { + return + } + s.bundleCapture.cleanup() + s.bundleCapture = nil +} + +func (s *Server) getCaptureEngine() (*internal.Engine, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.getCaptureEngineLocked() +} + +func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) { + if s.connectClient == nil { + return nil, status.Error(codes.FailedPrecondition, "client not connected") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, status.Error(codes.FailedPrecondition, "engine not initialized") + } + return engine, nil +} + +// parseCaptureFilter returns a Matcher from the request. +// Returns nil (match all) when no filter expression is set. +func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) { + expr := req.GetFilterExpr() + if expr == "" { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + return capture.ParseFilter(expr) +} diff --git a/client/server/debug.go b/client/server/debug.go index 81708e576..33247db5f 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -43,7 +43,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( }() } - // Prepare refresh callback for health probes + capturePath := s.bundleCapturePath() + defer s.cleanupBundleCapture() + var refreshStatus func() if s.connectClient != nil { engine := s.connectClient.Engine() @@ -62,6 +64,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( SyncResponse: syncResponse, LogPath: s.logFile, CPUProfile: cpuProfileData, + CapturePath: capturePath, RefreshStatus: refreshStatus, ClientMetrics: clientMetrics, }, diff --git a/client/server/server.go b/client/server/server.go index e12b6df5b..78d645d12 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -88,6 +88,8 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + captureEnabled bool + bundleCapture *bundleCapture sleepHandler *sleephandler.SleepHandler @@ -104,7 +106,7 @@ type oauthAuthFlow struct { } // New server instance constructor. -func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server { +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool) *Server { s := &Server{ rootCtx: ctx, logFile: logFile, @@ -113,6 +115,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + captureEnabled: captureEnabled, jwtCache: newJWTCache(), } agent := &serverAgent{s} diff --git a/client/server/server_test.go b/client/server/server_test.go index 6de23d501..c5148104f 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -103,7 +103,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false) + s := New(ctx, "debug", "", false, false, false) s.config = config @@ -164,7 +164,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) err = s.Start() require.NoError(t, err) @@ -234,7 +234,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) err = s.Start() require.NoError(t, err) diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 8e360175d..7f6847c43 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := New(ctx, "console", "", false, false) + s := New(ctx, "console", "", false, false, false) rosenpassEnabled := true rosenpassPermissive := true diff --git a/client/ui/debug.go b/client/ui/debug.go index 4ebe4d675..cf5ac1a75 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -16,6 +16,7 @@ import ( "fyne.io/fyne/v2/widget" log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" @@ -38,6 +39,7 @@ type debugCollectionParams struct { upload bool uploadURL string enablePersistence bool + capture bool } // UI components for progress tracking @@ -51,25 +53,58 @@ type progressUI struct { func (s *serviceClient) showDebugUI() { w := s.app.NewWindow("NetBird Debug") w.SetOnClosed(s.cancel) - w.Resize(fyne.NewSize(600, 500)) w.SetFixedSize(true) anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil) systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil) systemInfoCheck.SetChecked(true) + captureCheck := widget.NewCheck("Include packet capture", nil) uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) uploadCheck.SetChecked(true) - uploadURLLabel := widget.NewLabel("Debug upload URL:") + uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck) + + debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection() + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + progressBar := widget.NewProgressBar() + progressBar.Hide() + createButton := widget.NewButton("Create Debug Bundle", nil) + + uiControls := []fyne.Disableable{ + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton, + } + + createButton.OnTapped = s.getCreateHandler( + statusLabel, progressBar, uploadCheck, uploadURL, + anonymizeCheck, systemInfoCheck, captureCheck, + runForDurationCheck, durationInput, uiControls, w, + ) + + content := container.NewVBox( + widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), + widget.NewLabel(""), + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURLContainer, + widget.NewLabel(""), + debugModeContainer, noteLabel, + widget.NewLabel(""), + statusLabel, progressBar, createButton, + ) + + w.SetContent(container.NewPadded(content)) + w.Show() +} + +func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) { uploadURL := widget.NewEntry() uploadURL.SetText(uptypes.DefaultBundleURL) uploadURL.SetPlaceHolder("Enter upload URL") - uploadURLContainer := container.NewVBox( - uploadURLLabel, - uploadURL, - ) + uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL) uploadCheck.OnChanged = func(checked bool) { if checked { @@ -78,13 +113,14 @@ func (s *serviceClient) showDebugUI() { uploadURLContainer.Hide() } } + return uploadURLContainer, uploadURL +} - debugModeContainer := container.NewHBox() +func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) { runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil) runForDurationCheck.SetChecked(true) forLabel := widget.NewLabel("for") - durationInput := widget.NewEntry() durationInput.SetText("1") minutesLabel := widget.NewLabel("minute") @@ -108,63 +144,8 @@ func (s *serviceClient) showDebugUI() { } } - debugModeContainer.Add(runForDurationCheck) - debugModeContainer.Add(forLabel) - debugModeContainer.Add(durationInput) - debugModeContainer.Add(minutesLabel) - - statusLabel := widget.NewLabel("") - statusLabel.Hide() - - progressBar := widget.NewProgressBar() - progressBar.Hide() - - createButton := widget.NewButton("Create Debug Bundle", nil) - - // UI controls that should be disabled during debug collection - uiControls := []fyne.Disableable{ - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURL, - runForDurationCheck, - durationInput, - createButton, - } - - createButton.OnTapped = s.getCreateHandler( - statusLabel, - progressBar, - uploadCheck, - uploadURL, - anonymizeCheck, - systemInfoCheck, - runForDurationCheck, - durationInput, - uiControls, - w, - ) - - content := container.NewVBox( - widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), - widget.NewLabel(""), - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURLContainer, - widget.NewLabel(""), - debugModeContainer, - noteLabel, - widget.NewLabel(""), - statusLabel, - progressBar, - createButton, - ) - - paddedContent := container.NewPadded(content) - w.SetContent(paddedContent) - - w.Show() + modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel) + return modeContainer, runForDurationCheck, durationInput, noteLabel } func validateMinute(s string, minutesLabel *widget.Label) error { @@ -200,6 +181,7 @@ func (s *serviceClient) getCreateHandler( uploadURL *widget.Entry, anonymizeCheck *widget.Check, systemInfoCheck *widget.Check, + captureCheck *widget.Check, runForDurationCheck *widget.Check, duration *widget.Entry, uiControls []fyne.Disableable, @@ -222,6 +204,7 @@ func (s *serviceClient) getCreateHandler( params := &debugCollectionParams{ anonymize: anonymizeCheck.Checked, systemInfo: systemInfoCheck.Checked, + capture: captureCheck.Checked, upload: uploadCheck.Checked, uploadURL: url, enablePersistence: true, @@ -253,10 +236,7 @@ func (s *serviceClient) getCreateHandler( statusLabel.SetText("Creating debug bundle...") go s.handleDebugCreation( - anonymizeCheck.Checked, - systemInfoCheck.Checked, - uploadCheck.Checked, - url, + params, statusLabel, uiControls, w, @@ -371,7 +351,7 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time func (s *serviceClient) configureServiceForDebug( conn proto.DaemonServiceClient, state *debugInitialState, - enablePersistence bool, + params *debugCollectionParams, ) { if state.wasDown { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { @@ -397,7 +377,7 @@ func (s *serviceClient) configureServiceForDebug( time.Sleep(time.Second) } - if enablePersistence { + if params.enablePersistence { if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { @@ -417,6 +397,26 @@ func (s *serviceClient) configureServiceForDebug( if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { log.Warnf("failed to start CPU profiling: %v", err) } + + s.startBundleCaptureIfEnabled(conn, params) +} + +func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) { + if !params.capture { + return + } + + const maxCapture = 10 * time.Minute + timeout := params.duration + 30*time.Second + if timeout > maxCapture { + timeout = maxCapture + log.Warnf("packet capture clamped to %s (server maximum)", maxCapture) + } + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(timeout), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } } func (s *serviceClient) collectDebugData( @@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData( var wg sync.WaitGroup startProgressTracker(ctx, &wg, params.duration, progress) - s.configureServiceForDebug(conn, state, params.enablePersistence) + s.configureServiceForDebug(conn, state, params) wg.Wait() progress.progressBar.Hide() @@ -440,6 +440,14 @@ func (s *serviceClient) collectDebugData( log.Warnf("failed to stop CPU profiling: %v", err) } + if params.capture { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + } + return nil } @@ -520,18 +528,37 @@ func handleError(progress *progressUI, errMsg string) { } func (s *serviceClient) handleDebugCreation( - anonymize bool, - systemInfo bool, - upload bool, - uploadURL string, + params *debugCollectionParams, statusLabel *widget.Label, uiControls []fyne.Disableable, w fyne.Window, ) { - log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", - anonymize, systemInfo, upload) + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("Failed to get client for debug: %v", err) + statusLabel.SetText(fmt.Sprintf("Error: %v", err)) + enableUIControls(uiControls) + return + } - resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL) + if params.capture { + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(30 * time.Second), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } else { + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + }() + time.Sleep(2 * time.Second) + } + } + + resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL) if err != nil { log.Errorf("Failed to create debug bundle: %v", err) statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) @@ -543,7 +570,7 @@ func (s *serviceClient) handleDebugCreation( uploadFailureReason := resp.GetUploadFailureReason() uploadedKey := resp.GetUploadedKey() - if upload { + if params.upload { if uploadFailureReason != "" { showUploadFailedDialog(w, localPath, uploadFailureReason) } else { diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d8e50ab6d..cb512f132 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -5,6 +5,7 @@ package main import ( "context" "fmt" + "sync" "syscall/js" "time" @@ -14,6 +15,7 @@ import ( netbird "github.com/netbirdio/netbird/client/embed" sshdetection "github.com/netbirdio/netbird/client/ssh/detection" nbstatus "github.com/netbirdio/netbird/client/status" + wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -459,6 +461,95 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// createStartCaptureMethod creates the programmable packet capture method. +// Returns a JS interface with onpacket callback and stop() method. +// +// Usage from JavaScript: +// +// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true }) +// cap.onpacket = (line) => console.log(line) +// const stats = cap.stop() +func createStartCaptureMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + iface, err := wasmcapture.Start(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + resolve.Invoke(iface) + }) + }) +} + +// captureMethods returns capture() and stopCapture() that share state for +// the console-log shortcut. capture() logs packets to the browser console +// and stopCapture() ends it, like Ctrl+C on the CLI. +// +// Usage from browser devtools console: +// +// await client.capture() // capture all packets +// await client.capture("tcp") // capture with filter +// await client.capture({filter: "host 10.0.0.1", verbose: true}) +// client.stopCapture() // stop and print stats +func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) { + var mu sync.Mutex + var active *wasmcapture.Handle + + startFn = js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + mu.Lock() + defer mu.Unlock() + + if active != nil { + active.Stop() + active = nil + } + + h, err := wasmcapture.StartConsole(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + active = h + + console := js.Global().Get("console") + console.Call("log", "[capture] started, call client.stopCapture() to stop") + resolve.Invoke(js.Undefined()) + }) + }) + + stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + mu.Lock() + defer mu.Unlock() + + if active == nil { + js.Global().Get("console").Call("log", "[capture] no active capture") + return js.Undefined() + } + + stats := active.Stop() + active = nil + + console := js.Global().Get("console") + console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped)) + return js.Undefined() + }) + + return startFn, stopFn +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { @@ -521,6 +612,11 @@ func createClientObject(client *netbird.Client) js.Value { obj["statusDetail"] = createStatusDetailMethod(client) obj["getSyncResponse"] = createGetSyncResponseMethod(client) obj["setLogLevel"] = createSetLogLevelMethod(client) + obj["startCapture"] = createStartCaptureMethod(client) + + capStart, capStop := captureMethods(client) + obj["capture"] = capStart + obj["stopCapture"] = capStop return js.ValueOf(obj) } diff --git a/client/wasm/internal/capture/capture.go b/client/wasm/internal/capture/capture.go new file mode 100644 index 000000000..f6dc263f8 --- /dev/null +++ b/client/wasm/internal/capture/capture.go @@ -0,0 +1,211 @@ +//go:build js + +// Package capture bridges the util/capture package to JavaScript via syscall/js. +package capture + +import ( + "strings" + "sync" + "syscall/js" + + log "github.com/sirupsen/logrus" + + netbird "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/util/capture" +) + +// Handle holds a running capture session and the embedded client reference +// so it can be stopped later. +type Handle struct { + client *netbird.Client + sess *capture.Session + stopFn js.Func + stopped bool +} + +// Stop ends the capture and returns stats. +func (h *Handle) Stop() capture.Stats { + if h.stopped { + return h.sess.Stats() + } + h.stopped = true + h.stopFn.Release() + + if err := h.client.SetCapture(nil); err != nil { + log.Debugf("clear capture: %v", err) + } + h.sess.Stop() + return h.sess.Stats() +} + +func statsToJS(s capture.Stats) js.Value { + obj := js.Global().Get("Object").Call("create", js.Null()) + obj.Set("packets", js.ValueOf(s.Packets)) + obj.Set("bytes", js.ValueOf(s.Bytes)) + obj.Set("dropped", js.ValueOf(s.Dropped)) + return obj +} + +// parseOpts extracts filter/verbose/ascii from a JS options value. +func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) { + if jsOpts.IsNull() || jsOpts.IsUndefined() { + return + } + if jsOpts.Type() == js.TypeString { + filter = jsOpts.String() + return + } + if jsOpts.Type() != js.TypeObject { + return + } + if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() { + filter = f.String() + } + if v := jsOpts.Get("verbose"); !v.IsUndefined() { + verbose = v.Truthy() + } + if a := jsOpts.Get("ascii"); !a.IsUndefined() { + ascii = a.Truthy() + } + return +} + +func buildMatcher(filter string) (capture.Matcher, error) { + if filter == "" { + return nil, nil + } + return capture.ParseFilter(filter) +} + +// Start creates a capture session and returns a JS interface for streaming text +// output. The returned object exposes: +// +// onpacket(callback) - set callback(string) for each text line +// stop() - stop capture and return stats { packets, bytes, dropped } +// +// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string. +func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + matcher, err := buildMatcher(filter) + if err != nil { + return js.Undefined(), err + } + + cb := &jsCallbackWriter{} + + sess, err := capture.NewSession(capture.Options{ + TextOutput: cb, + Matcher: matcher, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return js.Undefined(), err + } + + if err := client.SetCapture(sess); err != nil { + sess.Stop() + return js.Undefined(), err + } + + handle := &Handle{client: client, sess: sess} + + iface := js.Global().Get("Object").Call("create", js.Null()) + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + iface.Set("stop", handle.stopFn) + iface.Set("onpacket", js.Undefined()) + cb.setInterface(iface) + + return iface, nil +} + +// StartConsole starts a capture that logs every packet line to console.log. +// Returns a Handle so the caller can stop it later. +func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + matcher, err := buildMatcher(filter) + if err != nil { + return nil, err + } + + cb := &jsCallbackWriter{} + + sess, err := capture.NewSession(capture.Options{ + TextOutput: cb, + Matcher: matcher, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return nil, err + } + + if err := client.SetCapture(sess); err != nil { + sess.Stop() + return nil, err + } + + handle := &Handle{client: client, sess: sess} + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + + iface := js.Global().Get("Object").Call("create", js.Null()) + console := js.Global().Get("console") + iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]"))) + cb.setInterface(iface) + + return handle, nil +} + +// jsCallbackWriter is an io.Writer that buffers text until a newline, then +// invokes the JS onpacket callback with each complete line. +type jsCallbackWriter struct { + mu sync.Mutex + iface js.Value + buf strings.Builder +} + +func (w *jsCallbackWriter) setInterface(iface js.Value) { + w.mu.Lock() + defer w.mu.Unlock() + w.iface = iface +} + +func (w *jsCallbackWriter) Write(p []byte) (int, error) { + w.mu.Lock() + w.buf.Write(p) + + var lines []string + for { + str := w.buf.String() + idx := strings.IndexByte(str, '\n') + if idx < 0 { + break + } + lines = append(lines, str[:idx]) + w.buf.Reset() + if idx+1 < len(str) { + w.buf.WriteString(str[idx+1:]) + } + } + + iface := w.iface + w.mu.Unlock() + + if iface.IsUndefined() { + return len(p), nil + } + cb := iface.Get("onpacket") + if cb.IsUndefined() || cb.IsNull() { + return len(p), nil + } + for _, line := range lines { + cb.Invoke(js.ValueOf(line)) + } + return len(p), nil +} diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b65..1b1664490 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -2,7 +2,12 @@ package cmd import ( "fmt" + "os" + "os/signal" + "path/filepath" "strconv" + "strings" + "syscall" "github.com/spf13/cobra" @@ -99,6 +104,27 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugCaptureCmd = &cobra.Command{ + Use: "capture [filter expression]", + Short: "Capture packets on a client's WireGuard interface", + Long: `Captures decrypted packets flowing through a client's WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Filter arguments after the account ID use BPF-like syntax. + +Examples: + netbird-proxy debug capture + netbird-proxy debug capture --duration 1m host 10.0.0.1 + netbird-proxy debug capture host 10.0.0.1 and tcp port 443 + netbird-proxy debug capture not port 22 + netbird-proxy debug capture -o capture.pcap + netbird-proxy debug capture --pcap | tcpdump -r - -n + netbird-proxy debug capture --pcap | tshark -r -`, + Args: cobra.MinimumNArgs(1), + RunE: runDebugCapture, + SilenceUsage: true, +} + func init() { debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address") debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format") @@ -110,6 +136,12 @@ func init() { debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)") + debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)") + debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)") + debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") + debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) debugCmd.AddCommand(debugStatusCmd) @@ -119,6 +151,7 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugCmd.AddCommand(debugCaptureCmd) rootCmd.AddCommand(debugCmd) } @@ -171,3 +204,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error { func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } + +func runDebugCapture(cmd *cobra.Command, args []string) error { + duration, _ := cmd.Flags().GetDuration("duration") + forcePcap, _ := cmd.Flags().GetBool("pcap") + verbose, _ := cmd.Flags().GetBool("verbose") + ascii, _ := cmd.Flags().GetBool("ascii") + outPath, _ := cmd.Flags().GetString("output") + + // Default to text. Use pcap when --pcap is set or --output is given. + wantText := !forcePcap && outPath == "" + + var filterExpr string + if len(args) > 1 { + filterExpr = strings.Join(args[1:], " ") + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + out, cleanup, err := captureOutputWriter(cmd, outPath) + if err != nil { + return err + } + defer cleanup() + + if wantText { + cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.") + } else { + cmd.PrintErrln("Capturing packets (pcap)... Press Ctrl+C to stop.") + } + + var durationStr string + if duration > 0 { + durationStr = duration.String() + } + + err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{ + AccountID: args[0], + Duration: durationStr, + FilterExpr: filterExpr, + Text: wantText, + Verbose: verbose, + ASCII: ascii, + Output: out, + }) + if err != nil { + return err + } + + cmd.PrintErrln("\nCapture finished.") + return nil +} + +// captureOutputWriter returns the writer and cleanup function for capture output. +func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) { + if outPath != "" { + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() { + if err := f.Close(); err != nil { + cmd.PrintErrf("close output file: %v\n", err) + } + if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 { + if err := os.Rename(tmpPath, outPath); err != nil { + cmd.PrintErrf("rename output file: %v\n", err) + } else { + cmd.PrintErrf("Wrote %s\n", outPath) + } + } else { + os.Remove(tmpPath) + } + }, nil + } + + return os.Stdout, func() { + // no cleanup needed for stdout + }, nil +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6..e01149522 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -310,6 +310,76 @@ func (c *Client) printError(data map[string]any) { } } +// CaptureOptions configures a capture request. +type CaptureOptions struct { + AccountID string + Duration string + FilterExpr string + Text bool + Verbose bool + ASCII bool + Output io.Writer +} + +// Capture streams a packet capture from the debug endpoint. The response body +// (pcap or text) is written directly to opts.Output until the server closes the +// connection or the context is cancelled. +func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error { + if opts.AccountID == "" { + return fmt.Errorf("account ID is required") + } + if opts.Output == nil { + return fmt.Errorf("output writer is required") + } + + params := url.Values{} + if opts.Duration != "" { + params.Set("duration", opts.Duration) + } + if opts.FilterExpr != "" { + params.Set("filter", opts.FilterExpr) + } + if opts.Text { + params.Set("format", "text") + } + if opts.Verbose { + params.Set("verbose", "true") + } + if opts.ASCII { + params.Set("ascii", "true") + } + + path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID)) + if len(params) > 0 { + path += "?" + params.Encode() + } + + fullURL := c.baseURL + path + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + // Use a separate client without timeout since captures stream for their full duration. + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + _, err = io.Copy(opts.Output, resp.Body) + if err != nil && ctx.Err() != nil { + return nil + } + return err +} + func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error { data, raw, err := c.fetch(ctx, path) if err != nil { diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9..36299837b 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/util/capture" "github.com/netbirdio/netbird/version" ) @@ -174,6 +175,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat h.handleClientStart(w, r, accountID) case "stop": h.handleClientStop(w, r, accountID) + case "capture": + h.handleCapture(w, r, accountID) default: return false } @@ -632,6 +635,99 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou }) } +const maxCaptureDuration = 30 * time.Minute + +// handleCapture streams a pcap or text packet capture for the given client. +// +// Query params: +// +// duration: capture duration (0 or absent = max, capped at 30m) +// format: "text" for human-readable output (default: pcap) +// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443") +func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "client not found", http.StatusNotFound) + return + } + + duration := maxCaptureDuration + if durationStr := r.URL.Query().Get("duration"); durationStr != "" { + d, err := time.ParseDuration(durationStr) + if err != nil { + http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest) + return + } + if d < 0 { + http.Error(w, "duration must not be negative", http.StatusBadRequest) + return + } + if d > 0 { + duration = min(d, maxCaptureDuration) + } + } + + var matcher capture.Matcher + if expr := r.URL.Query().Get("filter"); expr != "" { + var err error + matcher, err = capture.ParseFilter(expr) + if err != nil { + http.Error(w, "invalid filter: "+err.Error(), http.StatusBadRequest) + return + } + } + + wantText := r.URL.Query().Get("format") == "text" + verbose := r.URL.Query().Get("verbose") == "true" + ascii := r.URL.Query().Get("ascii") == "true" + + opts := capture.Options{Matcher: matcher, Verbose: verbose, ASCII: ascii} + if wantText { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + opts.TextOutput = w + } else { + w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap") + w.Header().Set("Content-Disposition", + fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID)) + opts.Output = w + } + + sess, err := capture.NewSession(opts) + if err != nil { + http.Error(w, "create capture session: "+err.Error(), http.StatusInternalServerError) + return + } + defer sess.Stop() + + if err := client.SetCapture(sess); err != nil { + http.Error(w, "set capture: "+err.Error(), http.StatusServiceUnavailable) + return + } + defer client.SetCapture(nil) //nolint:errcheck + + // Flush headers after setup succeeds so errors above can still set status codes. + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-r.Context().Done(): + case <-timer.C: + } + + if err := client.SetCapture(nil); err != nil { + h.logger.Debugf("clear capture: %v", err) + } + sess.Stop() + + stats := sess.Stats() + h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped", + accountID, stats.Packets, stats.Bytes, stats.Dropped) +} + func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) { if !wantJSON { http.Redirect(w, r, "/debug", http.StatusSeeOther) diff --git a/util/capture/afpacket_linux.go b/util/capture/afpacket_linux.go new file mode 100644 index 000000000..8c382dba2 --- /dev/null +++ b/util/capture/afpacket_linux.go @@ -0,0 +1,193 @@ +package capture + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// htons converts a uint16 from host to network (big-endian) byte order. +func htons(v uint16) uint16 { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], v) + return binary.NativeEndian.Uint16(buf[:]) +} + +// AFPacketCapture reads raw packets from a network interface using an +// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is +// not available (kernel WireGuard). Linux only. +// +// It implements device.PacketCapture so it can be set on a Session, but it +// drives its own read loop rather than being called from FilteredDevice. +// Call Start to begin and Stop to end. +type AFPacketCapture struct { + ifaceName string + sess *Session + fd int + mu sync.Mutex + stopped chan struct{} + started atomic.Bool + closed atomic.Bool +} + +// NewAFPacketCapture creates a capture bound to the given interface. +// The session receives packets via Offer. +func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture { + return &AFPacketCapture{ + ifaceName: ifaceName, + sess: sess, + fd: -1, + stopped: make(chan struct{}), + } +} + +// Start opens the AF_PACKET socket and begins reading packets. +// Packets are fed to the session via Offer. Returns immediately; +// the read loop runs in a goroutine. +func (c *AFPacketCapture) Start() error { + if c.sess == nil { + return errors.New("nil capture session") + } + if c.started.Load() { + return errors.New("capture already started") + } + + iface, err := net.InterfaceByName(c.ifaceName) + if err != nil { + return fmt.Errorf("interface %s: %w", c.ifaceName, err) + } + + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL))) + if err != nil { + return fmt.Errorf("create AF_PACKET socket: %w", err) + } + + addr := &unix.SockaddrLinklayer{ + Protocol: htons(unix.ETH_P_ALL), + Ifindex: iface.Index, + } + if err := unix.Bind(fd, addr); err != nil { + unix.Close(fd) + return fmt.Errorf("bind to %s: %w", c.ifaceName, err) + } + + c.mu.Lock() + c.fd = fd + c.mu.Unlock() + + c.started.Store(true) + go c.readLoop(fd) + return nil +} + +// Stop closes the socket and waits for the read loop to exit. Idempotent. +func (c *AFPacketCapture) Stop() { + if !c.closed.CompareAndSwap(false, true) { + if c.started.Load() { + <-c.stopped + } + return + } + + c.mu.Lock() + fd := c.fd + c.fd = -1 + c.mu.Unlock() + + if fd >= 0 { + unix.Close(fd) + } + + if c.started.Load() { + <-c.stopped + } +} + +func (c *AFPacketCapture) readLoop(fd int) { + defer close(c.stopped) + + buf := make([]byte, 65536) + pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + + for { + if c.closed.Load() { + return + } + + ok, err := c.pollOnce(pollFds) + if err != nil { + return + } + if !ok { + continue + } + + c.recvAndOffer(fd, buf) + } +} + +// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry. +// Returns an error to signal the loop should exit. +func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) { + n, err := unix.Poll(pollFds, 200) + if err != nil { + if errors.Is(err, unix.EINTR) { + return false, nil + } + if c.closed.Load() { + return false, errors.New("closed") + } + log.Debugf("af_packet poll: %v", err) + return false, err + } + if n == 0 { + return false, nil + } + if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 { + return false, errors.New("fd error") + } + return true, nil +} + +func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) { + nr, from, err := unix.Recvfrom(fd, buf, 0) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) { + return + } + if !c.closed.Load() { + log.Debugf("af_packet recvfrom: %v", err) + } + return + } + if nr < 1 { + return + } + + ver := buf[0] >> 4 + if ver != 4 && ver != 6 { + return + } + + // The kernel sets Pkttype on AF_PACKET sockets: + // PACKET_HOST(0) = addressed to us (inbound) + // PACKET_OUTGOING(4) = sent by us (outbound) + outbound := false + if sa, ok := from.(*unix.SockaddrLinklayer); ok { + outbound = sa.Pkttype == unix.PACKET_OUTGOING + } + c.sess.Offer(buf[:nr], outbound) +} + +// Offer satisfies device.PacketCapture but is unused: the AFPacketCapture +// drives its own read loop. This exists only so the type signature is +// compatible if someone tries to set it as a PacketCapture. +func (c *AFPacketCapture) Offer([]byte, bool) { + // unused: AFPacketCapture drives its own read loop +} diff --git a/util/capture/afpacket_stub.go b/util/capture/afpacket_stub.go new file mode 100644 index 000000000..bde368e88 --- /dev/null +++ b/util/capture/afpacket_stub.go @@ -0,0 +1,26 @@ +//go:build !linux + +package capture + +import "errors" + +// AFPacketCapture is not available on this platform. +type AFPacketCapture struct{} + +// NewAFPacketCapture returns nil on non-Linux platforms. +func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil } + +// Start returns an error on non-Linux platforms. +func (c *AFPacketCapture) Start() error { + return errors.New("AF_PACKET capture is only supported on Linux") +} + +// Stop is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Stop() { + // no-op on non-Linux platforms +} + +// Offer is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Offer([]byte, bool) { + // no-op on non-Linux platforms +} diff --git a/util/capture/capture.go b/util/capture/capture.go new file mode 100644 index 000000000..0d92a4311 --- /dev/null +++ b/util/capture/capture.go @@ -0,0 +1,59 @@ +// Package capture provides userspace packet capture in pcap format. +// +// It taps decrypted WireGuard packets flowing through the FilteredDevice and +// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as +// human-readable one-line-per-packet text. +package capture + +import "io" + +// Direction indicates whether a packet is entering or leaving the host. +type Direction uint8 + +const ( + // Inbound is a packet arriving from the network (FilteredDevice.Write path). + Inbound Direction = iota + // Outbound is a packet leaving the host (FilteredDevice.Read path). + Outbound +) + +// String returns "IN" or "OUT". +func (d Direction) String() string { + if d == Outbound { + return "OUT" + } + return "IN" +} + +const ( + protoICMP = 1 + protoTCP = 6 + protoUDP = 17 + protoICMPv6 = 58 +) + +// Options configures a capture session. +type Options struct { + // Output receives pcap-formatted data. Nil disables pcap output. + Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Matcher selects which packets to capture. Nil captures all. + // Use ParseFilter("host 10.0.0.1 and tcp") or &Filter{...}. + Matcher Matcher + // Verbose adds seq/ack, TTL, window, total length to text output. + Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool + // SnapLen is the maximum bytes captured per packet. 0 means 65535. + SnapLen uint32 + // BufSize is the internal channel buffer size. 0 means 256. + BufSize int +} + +// Stats reports capture session counters. +type Stats struct { + Packets int64 + Bytes int64 + Dropped int64 +} diff --git a/util/capture/filter.go b/util/capture/filter.go new file mode 100644 index 000000000..d463450b8 --- /dev/null +++ b/util/capture/filter.go @@ -0,0 +1,528 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "net/netip" + "strconv" + "strings" +) + +// Matcher tests whether a raw packet should be captured. +type Matcher interface { + Match(data []byte) bool +} + +// Filter selects packets by flat AND'd criteria. Useful for structured APIs +// (query params, proto fields). Implements Matcher. +type Filter struct { + SrcIP netip.Addr + DstIP netip.Addr + Host netip.Addr + SrcPort uint16 + DstPort uint16 + Port uint16 + Proto uint8 +} + +// IsEmpty returns true if the filter has no criteria set. +func (f *Filter) IsEmpty() bool { + return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() && + f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0 +} + +// Match implements Matcher. All non-zero fields must match (AND). +func (f *Filter) Match(data []byte) bool { + if f.IsEmpty() { + return true + } + info, ok := parsePacketInfo(data) + if !ok { + return false + } + if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host { + return false + } + if f.SrcIP.IsValid() && info.srcIP != f.SrcIP { + return false + } + if f.DstIP.IsValid() && info.dstIP != f.DstIP { + return false + } + if f.Proto != 0 && info.proto != f.Proto { + return false + } + if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port { + return false + } + if f.SrcPort != 0 && info.srcPort != f.SrcPort { + return false + } + if f.DstPort != 0 && info.dstPort != f.DstPort { + return false + } + return true +} + +// exprNode evaluates a filter condition against pre-parsed packet info. +type exprNode func(info *packetInfo) bool + +// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree. +type exprMatcher struct { + root exprNode +} + +func (m *exprMatcher) Match(data []byte) bool { + info, ok := parsePacketInfo(data) + if !ok { + return false + } + return m.root(&info) +} + +func nodeAnd(a, b exprNode) exprNode { + return func(info *packetInfo) bool { return a(info) && b(info) } +} + +func nodeOr(a, b exprNode) exprNode { + return func(info *packetInfo) bool { return a(info) || b(info) } +} + +func nodeNot(n exprNode) exprNode { + return func(info *packetInfo) bool { return !n(info) } +} + +func nodeHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr } +} + +func nodeSrcHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.srcIP == addr } +} + +func nodeDstHost(addr netip.Addr) exprNode { + return func(info *packetInfo) bool { return info.dstIP == addr } +} + +func nodePort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port } +} + +func nodeSrcPort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.srcPort == port } +} + +func nodeDstPort(port uint16) exprNode { + return func(info *packetInfo) bool { return info.dstPort == port } +} + +func nodeProto(proto uint8) exprNode { + return func(info *packetInfo) bool { return info.proto == proto } +} + +func nodeFamily(family uint8) exprNode { + return func(info *packetInfo) bool { return info.family == family } +} + +func nodeNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) } +} + +func nodeSrcNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) } +} + +func nodeDstNet(prefix netip.Prefix) exprNode { + return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) } +} + +// packetInfo holds parsed header fields for filtering and display. +type packetInfo struct { + family uint8 + srcIP netip.Addr + dstIP netip.Addr + proto uint8 + srcPort uint16 + dstPort uint16 + hdrLen int +} + +func parsePacketInfo(data []byte) (packetInfo, bool) { + if len(data) < 1 { + return packetInfo{}, false + } + switch data[0] >> 4 { + case 4: + return parseIPv4Info(data) + case 6: + return parseIPv6Info(data) + default: + return packetInfo{}, false + } +} + +func parseIPv4Info(data []byte) (packetInfo, bool) { + if len(data) < 20 { + return packetInfo{}, false + } + ihl := int(data[0]&0x0f) * 4 + if ihl < 20 || len(data) < ihl { + return packetInfo{}, false + } + info := packetInfo{ + family: 4, + srcIP: netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}), + dstIP: netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}), + proto: data[9], + hdrLen: ihl, + } + if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= ihl+4 { + info.srcPort = binary.BigEndian.Uint16(data[ihl:]) + info.dstPort = binary.BigEndian.Uint16(data[ihl+2:]) + } + return info, true +} + +// parseIPv6Info parses the fixed IPv6 header. It reads the Next Header field +// directly, so packets with extension headers (hop-by-hop, routing, fragment, +// etc.) will report the extension type as the protocol rather than the final +// transport protocol. This is acceptable for a debug capture tool. +func parseIPv6Info(data []byte) (packetInfo, bool) { + if len(data) < 40 { + return packetInfo{}, false + } + var src, dst [16]byte + copy(src[:], data[8:24]) + copy(dst[:], data[24:40]) + info := packetInfo{ + family: 6, + srcIP: netip.AddrFrom16(src), + dstIP: netip.AddrFrom16(dst), + proto: data[6], + hdrLen: 40, + } + if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 { + info.srcPort = binary.BigEndian.Uint16(data[40:]) + info.dstPort = binary.BigEndian.Uint16(data[42:]) + } + return info, true +} + +// ParseFilter parses a BPF-like filter expression and returns a Matcher. +// Returns nil Matcher for an empty expression (match all). +// +// Grammar (mirrors common tcpdump BPF syntax): +// +// orExpr = andExpr ("or" andExpr)* +// andExpr = unary ("and" unary)* +// unary = "not" unary | "(" orExpr ")" | term +// +// term = "host" IP | "src" target | "dst" target +// | "port" NUM | "net" PREFIX +// | "tcp" | "udp" | "icmp" | "icmp6" +// | "ip" | "ip6" | "proto" NUM +// target = "host" IP | "port" NUM | "net" PREFIX | IP +// +// Examples: +// +// host 10.0.0.1 and tcp port 443 +// not port 22 +// (host 10.0.0.1 or host 10.0.0.2) and tcp +// ip6 and icmp6 +// net 10.0.0.0/24 +// src host 10.0.0.1 or dst port 80 +func ParseFilter(expr string) (Matcher, error) { + tokens := tokenize(expr) + if len(tokens) == 0 { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + + p := &parser{tokens: tokens} + node, err := p.parseOr() + if err != nil { + return nil, err + } + if p.pos < len(p.tokens) { + return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos) + } + return &exprMatcher{root: node}, nil +} + +func tokenize(expr string) []string { + expr = strings.TrimSpace(expr) + if expr == "" { + return nil + } + // Split on whitespace but keep parens as separate tokens. + var tokens []string + for _, field := range strings.Fields(expr) { + tokens = append(tokens, splitParens(field)...) + } + return tokens +} + +// splitParens splits "(foo)" into "(", "foo", ")". +func splitParens(s string) []string { + var out []string + for strings.HasPrefix(s, "(") { + out = append(out, "(") + s = s[1:] + } + var trail []string + for strings.HasSuffix(s, ")") { + trail = append(trail, ")") + s = s[:len(s)-1] + } + if s != "" { + out = append(out, s) + } + out = append(out, trail...) + return out +} + +type parser struct { + tokens []string + pos int +} + +func (p *parser) peek() string { + if p.pos >= len(p.tokens) { + return "" + } + return strings.ToLower(p.tokens[p.pos]) +} + +func (p *parser) next() string { + tok := p.peek() + if tok != "" { + p.pos++ + } + return tok +} + +func (p *parser) expect(tok string) error { + got := p.next() + if got != tok { + return fmt.Errorf("expected %q, got %q", tok, got) + } + return nil +} + +func (p *parser) parseOr() (exprNode, error) { + left, err := p.parseAnd() + if err != nil { + return nil, err + } + for p.peek() == "or" { + p.next() + right, err := p.parseAnd() + if err != nil { + return nil, err + } + left = nodeOr(left, right) + } + return left, nil +} + +func (p *parser) parseAnd() (exprNode, error) { + left, err := p.parseUnary() + if err != nil { + return nil, err + } + for { + tok := p.peek() + if tok == "and" { + p.next() + right, err := p.parseUnary() + if err != nil { + return nil, err + } + left = nodeAnd(left, right) + continue + } + // Implicit AND: two atoms without "and" between them. + // Only if the next token starts an atom (not "or", ")", or EOF). + if tok != "" && tok != "or" && tok != ")" { + right, err := p.parseUnary() + if err != nil { + return nil, err + } + left = nodeAnd(left, right) + continue + } + break + } + return left, nil +} + +func (p *parser) parseUnary() (exprNode, error) { + switch p.peek() { + case "not": + p.next() + inner, err := p.parseUnary() + if err != nil { + return nil, err + } + return nodeNot(inner), nil + case "(": + p.next() + inner, err := p.parseOr() + if err != nil { + return nil, err + } + if err := p.expect(")"); err != nil { + return nil, fmt.Errorf("unclosed parenthesis") + } + return inner, nil + default: + return p.parseAtom() + } +} + +func (p *parser) parseAtom() (exprNode, error) { + tok := p.next() + if tok == "" { + return nil, fmt.Errorf("unexpected end of expression") + } + + switch tok { + case "host": + addr, err := p.parseAddr() + if err != nil { + return nil, fmt.Errorf("host: %w", err) + } + return nodeHost(addr), nil + + case "port": + port, err := p.parsePort() + if err != nil { + return nil, fmt.Errorf("port: %w", err) + } + return nodePort(port), nil + + case "net": + prefix, err := p.parsePrefix() + if err != nil { + return nil, fmt.Errorf("net: %w", err) + } + return nodeNet(prefix), nil + + case "src": + return p.parseDirTarget(true) + + case "dst": + return p.parseDirTarget(false) + + case "tcp": + return nodeProto(protoTCP), nil + case "udp": + return nodeProto(protoUDP), nil + case "icmp": + return nodeProto(protoICMP), nil + case "icmp6": + return nodeProto(protoICMPv6), nil + case "ip": + return nodeFamily(4), nil + case "ip6": + return nodeFamily(6), nil + + case "proto": + raw := p.next() + if raw == "" { + return nil, fmt.Errorf("proto: missing number") + } + n, err := strconv.Atoi(raw) + if err != nil || n < 0 || n > 255 { + return nil, fmt.Errorf("proto: invalid number %q", raw) + } + return nodeProto(uint8(n)), nil + + default: + return nil, fmt.Errorf("unknown filter keyword %q", tok) + } +} + +func (p *parser) parseDirTarget(isSrc bool) (exprNode, error) { + tok := p.peek() + switch tok { + case "host": + p.next() + addr, err := p.parseAddr() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcHost(addr), nil + } + return nodeDstHost(addr), nil + + case "port": + p.next() + port, err := p.parsePort() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcPort(port), nil + } + return nodeDstPort(port), nil + + case "net": + p.next() + prefix, err := p.parsePrefix() + if err != nil { + return nil, err + } + if isSrc { + return nodeSrcNet(prefix), nil + } + return nodeDstNet(prefix), nil + + default: + // Try as bare IP: "src 10.0.0.1" + addr, err := p.parseAddr() + if err != nil { + return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok) + } + if isSrc { + return nodeSrcHost(addr), nil + } + return nodeDstHost(addr), nil + } +} + +func (p *parser) parseAddr() (netip.Addr, error) { + raw := p.next() + if raw == "" { + return netip.Addr{}, fmt.Errorf("missing IP address") + } + addr, err := netip.ParseAddr(raw) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IP %q", raw) + } + return addr.Unmap(), nil +} + +func (p *parser) parsePort() (uint16, error) { + raw := p.next() + if raw == "" { + return 0, fmt.Errorf("missing port number") + } + n, err := strconv.Atoi(raw) + if err != nil || n < 1 || n > 65535 { + return 0, fmt.Errorf("invalid port %q", raw) + } + return uint16(n), nil +} + +func (p *parser) parsePrefix() (netip.Prefix, error) { + raw := p.next() + if raw == "" { + return netip.Prefix{}, fmt.Errorf("missing network prefix") + } + prefix, err := netip.ParsePrefix(raw) + if err != nil { + return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw) + } + return prefix, nil +} diff --git a/util/capture/filter_test.go b/util/capture/filter_test.go new file mode 100644 index 000000000..d5fd17566 --- /dev/null +++ b/util/capture/filter_test.go @@ -0,0 +1,263 @@ +package capture + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing. +func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + hdrLen := 20 + pkt := make([]byte, hdrLen+20) + pkt[0] = 0x45 + pkt[9] = proto + + src := srcIP.As4() + dst := dstIP.As4() + copy(pkt[12:16], src[:]) + copy(pkt[16:20], dst[:]) + + pkt[20] = byte(srcPort >> 8) + pkt[21] = byte(srcPort) + pkt[22] = byte(dstPort >> 8) + pkt[23] = byte(dstPort) + + return pkt +} + +// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing. +func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + pkt := make([]byte, 44) // 40 header + 4 ports + pkt[0] = 0x60 // version 6 + pkt[6] = proto // next header + + src := srcIP.As16() + dst := dstIP.As16() + copy(pkt[8:24], src[:]) + copy(pkt[24:40], dst[:]) + + pkt[40] = byte(srcPort >> 8) + pkt[41] = byte(srcPort) + pkt[42] = byte(dstPort >> 8) + pkt[43] = byte(dstPort) + + return pkt +} + +// ---- Filter struct tests ---- + +func TestFilter_Empty(t *testing.T) { + f := Filter{} + assert.True(t, f.IsEmpty()) + assert.True(t, f.Match(buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443))) +} + +func TestFilter_Host(t *testing.T) { + f := Filter{Host: netip.MustParseAddr("10.0.0.1")} + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80))) + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80))) + assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80))) +} + +func TestFilter_InvalidPacket(t *testing.T) { + f := Filter{Host: netip.MustParseAddr("10.0.0.1")} + assert.False(t, f.Match(nil)) + assert.False(t, f.Match([]byte{})) + assert.False(t, f.Match([]byte{0x00})) +} + +func TestParsePacketInfo_IPv4(t *testing.T) { + pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(4), info.family) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP) + assert.Equal(t, uint8(protoTCP), info.proto) + assert.Equal(t, uint16(54321), info.srcPort) + assert.Equal(t, uint16(80), info.dstPort) +} + +func TestParsePacketInfo_IPv6(t *testing.T) { + pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(6), info.family) + assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP) + assert.Equal(t, uint8(protoUDP), info.proto) + assert.Equal(t, uint16(1234), info.srcPort) + assert.Equal(t, uint16(53), info.dstPort) +} + +// ---- ParseFilter expression tests ---- + +func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func TestParseFilter_Empty(t *testing.T) { + m, err := ParseFilter("") + require.NoError(t, err) + assert.Nil(t, m, "empty expression should return nil matcher") +} + +func TestParseFilter_Atoms(t *testing.T) { + tests := []struct { + expr string + match bool + }{ + {"tcp", true}, + {"udp", false}, + {"host 10.0.0.1", true}, + {"host 10.0.0.99", false}, + {"port 443", true}, + {"port 80", false}, + {"src host 10.0.0.1", true}, + {"dst host 10.0.0.2", true}, + {"dst host 10.0.0.1", false}, + {"src port 12345", true}, + {"dst port 443", true}, + {"dst port 80", false}, + {"proto 6", true}, + {"proto 17", false}, + } + + pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443) + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + m, err := ParseFilter(tt.expr) + require.NoError(t, err) + assert.Equal(t, tt.match, m.Match(pkt)) + }) + } +} + +func TestParseFilter_And(t *testing.T) { + m, err := ParseFilter("host 10.0.0.1 and tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port") +} + +func TestParseFilter_ImplicitAnd(t *testing.T) { + // "tcp port 443" = implicit AND between tcp and port 443 + m, err := ParseFilter("tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443)) +} + +func TestParseFilter_Or(t *testing.T) { + m, err := ParseFilter("port 80 or port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80)) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080)) +} + +func TestParseFilter_Not(t *testing.T) { + m, err := ParseFilter("not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80)) +} + +func TestParseFilter_Parens(t *testing.T) { + m, err := ParseFilter("(port 80 or port 443) and tcp") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto") + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port") +} + +func TestParseFilter_Family(t *testing.T) { + mV4, err := ParseFilter("ip") + require.NoError(t, err) + assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80)) + + mV6, err := ParseFilter("ip6") + require.NoError(t, err) + assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80)) +} + +func TestParseFilter_Net(t *testing.T) { + m, err := ParseFilter("net 10.0.0.0/24") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net") + assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net") + assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net") +} + +func TestParseFilter_SrcDstNet(t *testing.T) { + m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80)) + assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed") +} + +func TestParseFilter_Complex(t *testing.T) { + // Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH + m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443)) + assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") +} + +func TestParseFilter_IPv6Combined(t *testing.T) { + m, err := ParseFilter("ip6 and icmp6") + require.NoError(t, err) + assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family") + assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto") +} + +func TestParseFilter_CaseInsensitive(t *testing.T) { + m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) +} + +func TestParseFilter_Errors(t *testing.T) { + bad := []string{ + "badkeyword", + "host", + "port abc", + "port 99999", + "net invalid", + "(", + "(port 80", + "not", + "src", + } + for _, expr := range bad { + t.Run(expr, func(t *testing.T) { + _, err := ParseFilter(expr) + assert.Error(t, err, "should fail for %q", expr) + }) + } +} diff --git a/util/capture/pcap.go b/util/capture/pcap.go new file mode 100644 index 000000000..0a9057045 --- /dev/null +++ b/util/capture/pcap.go @@ -0,0 +1,85 @@ +package capture + +import ( + "encoding/binary" + "io" + "time" +) + +const ( + pcapMagic = 0xa1b2c3d4 + pcapVersionMaj = 2 + pcapVersionMin = 4 + // linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header. + linkTypeRaw = 101 + defaultSnapLen = 65535 +) + +// PcapWriter writes packets in pcap format to an underlying writer. +// The global header is written lazily on the first WritePacket call so that +// the writer can be used with unbuffered io.Pipes without deadlocking. +// It is not safe for concurrent use; callers must serialize access. +type PcapWriter struct { + w io.Writer + snapLen uint32 + headerWritten bool +} + +// NewPcapWriter creates a pcap writer. The global header is deferred until the +// first WritePacket call. +func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter { + if snapLen == 0 { + snapLen = defaultSnapLen + } + return &PcapWriter{w: w, snapLen: snapLen} +} + +// writeGlobalHeader writes the 24-byte pcap file header. +func (pw *PcapWriter) writeGlobalHeader() error { + var hdr [24]byte + binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic) + binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj) + binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin) + binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen) + binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw) + + _, err := pw.w.Write(hdr[:]) + return err +} + +// WriteHeader writes the pcap global header. Safe to call multiple times. +func (pw *PcapWriter) WriteHeader() error { + if pw.headerWritten { + return nil + } + if err := pw.writeGlobalHeader(); err != nil { + return err + } + pw.headerWritten = true + return nil +} + +// WritePacket writes a single packet record, preceded by the global header +// on the first call. +func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error { + if err := pw.WriteHeader(); err != nil { + return err + } + + origLen := uint32(len(data)) + if origLen > pw.snapLen { + data = data[:pw.snapLen] + } + + var hdr [16]byte + binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix())) + binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000)) + binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data))) + binary.LittleEndian.PutUint32(hdr[12:16], origLen) + + if _, err := pw.w.Write(hdr[:]); err != nil { + return err + } + _, err := pw.w.Write(data) + return err +} diff --git a/util/capture/pcap_test.go b/util/capture/pcap_test.go new file mode 100644 index 000000000..c3d21ef4a --- /dev/null +++ b/util/capture/pcap_test.go @@ -0,0 +1,68 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPcapWriter_GlobalHeader(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 0) + + // Header is lazy, so write a dummy packet to trigger it. + err := pw.WritePacket(time.Now(), []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 64, 1, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2}) + require.NoError(t, err) + + data := buf.Bytes() + require.GreaterOrEqual(t, len(data), 24, "should contain global header") + + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]), "magic number") + assert.Equal(t, uint16(pcapVersionMaj), binary.LittleEndian.Uint16(data[4:6]), "version major") + assert.Equal(t, uint16(pcapVersionMin), binary.LittleEndian.Uint16(data[6:8]), "version minor") + assert.Equal(t, uint32(defaultSnapLen), binary.LittleEndian.Uint32(data[16:20]), "snap length") + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]), "link type") +} + +func TestPcapWriter_WritePacket(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 100) + + ts := time.Date(2025, 6, 15, 12, 30, 45, 123456000, time.UTC) + payload := make([]byte, 50) + for i := range payload { + payload[i] = byte(i) + } + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] // skip global header + require.Len(t, data, 16+50, "packet header + payload") + + assert.Equal(t, uint32(ts.Unix()), binary.LittleEndian.Uint32(data[0:4]), "timestamp seconds") + assert.Equal(t, uint32(123456), binary.LittleEndian.Uint32(data[4:8]), "timestamp microseconds") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[8:12]), "included length") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length") + assert.Equal(t, payload, data[16:], "packet data") +} + +func TestPcapWriter_SnapLen(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 10) + + ts := time.Now() + payload := make([]byte, 50) + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] + assert.Equal(t, uint32(10), binary.LittleEndian.Uint32(data[8:12]), "included length should be truncated") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length preserved") + assert.Len(t, data[16:], 10, "only snap_len bytes written") +} diff --git a/util/capture/session.go b/util/capture/session.go new file mode 100644 index 000000000..09806e10c --- /dev/null +++ b/util/capture/session.go @@ -0,0 +1,213 @@ +package capture + +import ( + "fmt" + "sync" + "sync/atomic" + "time" +) + +const defaultBufSize = 256 + +type packetEntry struct { + ts time.Time + data []byte + dir Direction +} + +// Session manages an active packet capture. Packets are offered via Offer, +// buffered in a channel, and written to configured sinks by a background +// goroutine. This keeps the hot path (FilteredDevice.Read/Write) non-blocking. +// +// The caller must call Stop when done to flush remaining packets and release +// resources. +type Session struct { + pcapW *PcapWriter + textW *TextWriter + matcher Matcher + snapLen uint32 + flushFn func() + + ch chan packetEntry + done chan struct{} + stopped chan struct{} + + closeOnce sync.Once + closed atomic.Bool + packets atomic.Int64 + bytes atomic.Int64 + dropped atomic.Int64 + started time.Time +} + +// NewSession creates and starts a capture session. At least one of +// Options.Output or Options.TextOutput must be non-nil. +func NewSession(opts Options) (*Session, error) { + if opts.Output == nil && opts.TextOutput == nil { + return nil, fmt.Errorf("at least one output sink required") + } + + snapLen := opts.SnapLen + if snapLen == 0 { + snapLen = defaultSnapLen + } + + bufSize := opts.BufSize + if bufSize <= 0 { + bufSize = defaultBufSize + } + + s := &Session{ + matcher: opts.Matcher, + snapLen: snapLen, + ch: make(chan packetEntry, bufSize), + done: make(chan struct{}), + stopped: make(chan struct{}), + started: time.Now(), + } + + if opts.Output != nil { + s.pcapW = NewPcapWriter(opts.Output, snapLen) + } + if opts.TextOutput != nil { + s.textW = NewTextWriter(opts.TextOutput, opts.Verbose, opts.ASCII) + } + + s.flushFn = buildFlushFn(opts.Output, opts.TextOutput) + + go s.run() + return s, nil +} + +// Offer submits a packet for capture. It returns immediately and never blocks +// the caller. If the internal buffer is full the packet is dropped silently. +// +// outbound should be true for packets leaving the host (FilteredDevice.Read +// path) and false for packets arriving (FilteredDevice.Write path). +// +// Offer satisfies the device.PacketCapture interface. +func (s *Session) Offer(data []byte, outbound bool) { + if s.closed.Load() { + return + } + + if s.matcher != nil && !s.matcher.Match(data) { + return + } + + captureLen := len(data) + if s.snapLen > 0 && uint32(captureLen) > s.snapLen { + captureLen = int(s.snapLen) + } + + copied := make([]byte, captureLen) + copy(copied, data) + + dir := Inbound + if outbound { + dir = Outbound + } + + select { + case s.ch <- packetEntry{ts: time.Now(), data: copied, dir: dir}: + s.packets.Add(1) + s.bytes.Add(int64(len(data))) + default: + s.dropped.Add(1) + } +} + +// Stop signals the session to stop accepting packets, drains any buffered +// packets to the sinks, and waits for the writer goroutine to exit. +// It is safe to call multiple times. +func (s *Session) Stop() { + s.closeOnce.Do(func() { + s.closed.Store(true) + close(s.done) + }) + <-s.stopped +} + +// Done returns a channel that is closed when the session's writer goroutine +// has fully exited and all buffered packets have been flushed. +func (s *Session) Done() <-chan struct{} { + return s.stopped +} + +// Stats returns current capture counters. +func (s *Session) Stats() Stats { + return Stats{ + Packets: s.packets.Load(), + Bytes: s.bytes.Load(), + Dropped: s.dropped.Load(), + } +} + +func (s *Session) run() { + defer close(s.stopped) + + for { + select { + case pkt := <-s.ch: + s.write(pkt) + case <-s.done: + s.drain() + return + } + } +} + +func (s *Session) drain() { + for { + select { + case pkt := <-s.ch: + s.write(pkt) + default: + return + } + } +} + +func (s *Session) write(pkt packetEntry) { + if s.pcapW != nil { + // Best-effort: if the writer fails (broken pipe etc.), discard silently. + _ = s.pcapW.WritePacket(pkt.ts, pkt.data) + } + if s.textW != nil { + _ = s.textW.WritePacket(pkt.ts, pkt.data, pkt.dir) + } + s.flushFn() +} + +// buildFlushFn returns a function that flushes all writers that support it. +// This covers http.Flusher and similar streaming writers. +func buildFlushFn(writers ...any) func() { + type flusher interface { + Flush() + } + + var fns []func() + for _, w := range writers { + if w == nil { + continue + } + if f, ok := w.(flusher); ok { + fns = append(fns, f.Flush) + } + } + + switch len(fns) { + case 0: + return func() { + // no writers to flush + } + case 1: + return fns[0] + default: + return func() { + for _, fn := range fns { + fn() + } + } + } +} diff --git a/util/capture/session_test.go b/util/capture/session_test.go new file mode 100644 index 000000000..ab27686c6 --- /dev/null +++ b/util/capture/session_test.go @@ -0,0 +1,144 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSession_PcapOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + + sess.Offer(pkt, true) + sess.Stop() + + data := buf.Bytes() + require.Greater(t, len(data), 24, "should have global header + at least one packet") + + // Verify global header + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4])) + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24])) + + // Verify packet record + pktData := data[24:] + inclLen := binary.LittleEndian.Uint32(pktData[8:12]) + assert.Equal(t, uint32(len(pkt)), inclLen) + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets) + assert.Equal(t, int64(len(pkt)), stats.Bytes) + assert.Equal(t, int64(0), stats.Dropped) +} + +func TestSession_TextOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + TextOutput: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + + sess.Offer(pkt, false) + sess.Stop() + + output := buf.String() + assert.Contains(t, output, "TCP") + assert.Contains(t, output, "10.0.0.1") + assert.Contains(t, output, "10.0.0.2") + assert.Contains(t, output, "443") + assert.Contains(t, output, "[IN TCP]") +} + +func TestSession_Filter(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + Matcher: &Filter{Port: 443}, + }) + require.NoError(t, err) + + pktMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + pktNoMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 80) + + sess.Offer(pktMatch, true) + sess.Offer(pktNoMatch, true) + sess.Stop() + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets, "only matching packet should be captured") +} + +func TestSession_StopIdempotent(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + sess.Stop() + sess.Stop() // should not panic or deadlock +} + +func TestSession_OfferAfterStop(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + sess.Stop() + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + sess.Offer(pkt, true) // should not panic + + assert.Equal(t, int64(0), sess.Stats().Packets) +} + +func TestSession_Done(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + select { + case <-sess.Done(): + t.Fatal("Done should not be closed before Stop") + default: + } + + sess.Stop() + + select { + case <-sess.Done(): + case <-time.After(time.Second): + t.Fatal("Done should be closed after Stop") + } +} + +func TestSession_RequiresOutput(t *testing.T) { + _, err := NewSession(Options{}) + assert.Error(t, err) +} diff --git a/util/capture/text.go b/util/capture/text.go new file mode 100644 index 000000000..b44bd0cad --- /dev/null +++ b/util/capture/text.go @@ -0,0 +1,638 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "io" + "net/netip" + "strings" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// TextWriter writes human-readable one-line-per-packet summaries. +// It is not safe for concurrent use; callers must serialize access. +type TextWriter struct { + w io.Writer + verbose bool + ascii bool + flows map[dirKey]uint32 +} + +type dirKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// NewTextWriter creates a text formatter that writes to w. +func NewTextWriter(w io.Writer, verbose, ascii bool) *TextWriter { + return &TextWriter{ + w: w, + verbose: verbose, + ascii: ascii, + flows: make(map[dirKey]uint32), + } +} + +// tag formats the fixed-width "[DIR PROTO]" prefix with right-aligned protocol. +func tag(dir Direction, proto string) string { + return fmt.Sprintf("[%-3s %4s]", dir, proto) +} + +// WritePacket formats and writes a single packet line. +func (tw *TextWriter) WritePacket(ts time.Time, data []byte, dir Direction) error { + ts = ts.Local() + info, ok := parsePacketInfo(data) + if !ok { + _, err := fmt.Fprintf(tw.w, "%s [%-3s ?] ??? len=%d\n", + ts.Format("15:04:05.000000"), dir, len(data)) + return err + } + + timeStr := ts.Format("15:04:05.000000") + + var err error + switch info.proto { + case protoTCP: + err = tw.writeTCP(timeStr, dir, &info, data) + case protoUDP: + err = tw.writeUDP(timeStr, dir, &info, data) + case protoICMP: + err = tw.writeICMPv4(timeStr, dir, &info, data) + case protoICMPv6: + err = tw.writeICMPv6(timeStr, dir, &info, data) + default: + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err = fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n", + timeStr, tag(dir, fmt.Sprintf("P%d", info.proto)), + info.srcIP, info.dstIP, len(data)-info.hdrLen, verbose) + } + return err +} + +func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + tcp := &layers.TCP{} + if err := tcp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "TCP", info, data) + } + + flags := tcpFlagsStr(tcp) + plen := len(tcp.Payload) + + // Protocol annotation + var annotation string + if plen > 0 { + annotation = annotatePayload(tcp.Payload) + } + + if !tw.verbose { + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s] length %d%s\n", + timeStr, tag(dir, "TCP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + flags, plen, annotation) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil + } + + relSeq, relAck := tw.relativeSeqAck(info, tcp.Seq, tcp.Ack) + + var seqStr string + if plen > 0 { + seqStr = fmt.Sprintf(", seq %d:%d", relSeq, relSeq+uint32(plen)) + } else { + seqStr = fmt.Sprintf(", seq %d", relSeq) + } + + var ackStr string + if tcp.ACK { + ackStr = fmt.Sprintf(", ack %d", relAck) + } + + var opts string + if s := formatTCPOptions(tcp.Options); s != "" { + opts = ", options [" + s + "]" + } + + verbose := tw.verboseIP(data, info.family) + + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s]%s%s, win %d%s, length %d%s%s\n", + timeStr, tag(dir, "TCP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil +} + +func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + udp := &layers.UDP{} + if err := udp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "UDP", info, data) + } + + plen := len(udp.Payload) + + // DNS replaces the entire line format + if plen > 0 && isDNSPort(info.srcPort, info.dstPort) { + if s := formatDNSPayload(udp.Payload); s != "" { + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d %s%s\n", + timeStr, tag(dir, "UDP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + s, verbose) + return err + } + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d%s\n", + timeStr, tag(dir, "UDP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + plen, verbose) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(udp.Payload) + } + return nil +} + +func (tw *TextWriter) writeICMPv4(timeStr string, dir Direction, info *packetInfo, data []byte) error { + icmp := &layers.ICMPv4{} + if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "ICMP", info, data) + } + + var detail string + if icmp.TypeCode.Type() == layers.ICMPv4TypeEchoRequest || icmp.TypeCode.Type() == layers.ICMPv4TypeEchoReply { + detail = fmt.Sprintf("%s, id %d, seq %d", icmp.TypeCode.String(), icmp.Id, icmp.Seq) + } else { + detail = icmp.TypeCode.String() + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n", + timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, detail, len(data)-info.hdrLen, verbose) + return err +} + +func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInfo, data []byte) error { + icmp := &layers.ICMPv6{} + if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "ICMP", info, data) + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n", + timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, icmp.TypeCode.String(), len(data)-info.hdrLen, verbose) + return err +} + +func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error { + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d\n", + timeStr, tag(dir, proto), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + len(data)-info.hdrLen) + return err +} + +func (tw *TextWriter) verboseIP(data []byte, family uint8) string { + return fmt.Sprintf(", ttl %d, id %d, iplen %d", + ipTTL(data, family), ipID(data, family), len(data)) +} + +// relativeSeqAck returns seq/ack relative to the first seen value per direction. +func (tw *TextWriter) relativeSeqAck(info *packetInfo, seq, ack uint32) (relSeq, relAck uint32) { + fwd := dirKey{ + src: netip.AddrPortFrom(info.srcIP, info.srcPort), + dst: netip.AddrPortFrom(info.dstIP, info.dstPort), + } + rev := dirKey{ + src: netip.AddrPortFrom(info.dstIP, info.dstPort), + dst: netip.AddrPortFrom(info.srcIP, info.srcPort), + } + + if isn, ok := tw.flows[fwd]; ok { + relSeq = seq - isn + } else { + tw.flows[fwd] = seq + } + + if isn, ok := tw.flows[rev]; ok { + relAck = ack - isn + } else { + relAck = ack + } + + return relSeq, relAck +} + +// writeASCII prints payload bytes as printable ASCII. +func (tw *TextWriter) writeASCII(payload []byte) error { + if len(payload) == 0 { + return nil + } + buf := make([]byte, len(payload)) + for i, b := range payload { + switch { + case b >= 0x20 && b < 0x7f: + buf[i] = b + case b == '\n' || b == '\r' || b == '\t': + buf[i] = b + default: + buf[i] = '.' + } + } + _, err := fmt.Fprintf(tw.w, "%s\n", buf) + return err +} + +// --- TCP helpers --- + +func ipTTL(data []byte, family uint8) uint8 { + if family == 4 && len(data) > 8 { + return data[8] + } + if family == 6 && len(data) > 7 { + return data[7] + } + return 0 +} + +func ipID(data []byte, family uint8) uint16 { + if family == 4 && len(data) >= 6 { + return binary.BigEndian.Uint16(data[4:6]) + } + return 0 +} + +func tcpFlagsStr(tcp *layers.TCP) string { + var buf [6]byte + n := 0 + if tcp.SYN { + buf[n] = 'S' + n++ + } + if tcp.FIN { + buf[n] = 'F' + n++ + } + if tcp.RST { + buf[n] = 'R' + n++ + } + if tcp.PSH { + buf[n] = 'P' + n++ + } + if tcp.ACK { + buf[n] = '.' + n++ + } + if tcp.URG { + buf[n] = 'U' + n++ + } + if n == 0 { + return "none" + } + return string(buf[:n]) +} + +func formatTCPOptions(opts []layers.TCPOption) string { + var parts []string + for _, opt := range opts { + switch opt.OptionType { + case layers.TCPOptionKindEndList: + return strings.Join(parts, ",") + case layers.TCPOptionKindNop: + parts = append(parts, "nop") + case layers.TCPOptionKindMSS: + if len(opt.OptionData) == 2 { + parts = append(parts, fmt.Sprintf("mss %d", binary.BigEndian.Uint16(opt.OptionData))) + } + case layers.TCPOptionKindWindowScale: + if len(opt.OptionData) == 1 { + parts = append(parts, fmt.Sprintf("wscale %d", opt.OptionData[0])) + } + case layers.TCPOptionKindSACKPermitted: + parts = append(parts, "sackOK") + case layers.TCPOptionKindSACK: + blocks := len(opt.OptionData) / 8 + parts = append(parts, fmt.Sprintf("sack %d", blocks)) + case layers.TCPOptionKindTimestamps: + if len(opt.OptionData) == 8 { + tsval := binary.BigEndian.Uint32(opt.OptionData[0:4]) + tsecr := binary.BigEndian.Uint32(opt.OptionData[4:8]) + parts = append(parts, fmt.Sprintf("TS val %d ecr %d", tsval, tsecr)) + } + } + } + return strings.Join(parts, ",") +} + +// --- Protocol annotation --- + +// annotatePayload returns a protocol annotation string for known application protocols. +func annotatePayload(payload []byte) string { + if len(payload) < 4 { + return "" + } + + s := string(payload) + + // SSH banner: "SSH-2.0-OpenSSH_9.6\r\n" + if strings.HasPrefix(s, "SSH-") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 256 { + return ": " + s[:end] + } + } + + // TLS records + if ann := annotateTLS(payload); ann != "" { + return ": " + ann + } + + // HTTP request or response + for _, method := range [...]string{"GET ", "POST ", "PUT ", "DELETE ", "HEAD ", "PATCH ", "OPTIONS ", "CONNECT "} { + if strings.HasPrefix(s, method) { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + } + if strings.HasPrefix(s, "HTTP/") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + + return "" +} + +// annotateTLS returns a description for TLS handshake and alert records. +func annotateTLS(data []byte) string { + if len(data) < 6 { + return "" + } + + switch data[0] { + case 0x16: + return annotateTLSHandshake(data) + case 0x15: + return annotateTLSAlert(data) + } + return "" +} + +func annotateTLSHandshake(data []byte) string { + if len(data) < 10 { + return "" + } + switch data[5] { + case 0x01: + if sni := extractSNI(data); sni != "" { + return "TLS ClientHello SNI=" + sni + } + return "TLS ClientHello" + case 0x02: + return "TLS ServerHello" + } + return "" +} + +func annotateTLSAlert(data []byte) string { + if len(data) < 7 { + return "" + } + severity := "warning" + if data[5] == 2 { + severity = "fatal" + } + return fmt.Sprintf("TLS Alert %s %s", severity, tlsAlertDesc(data[6])) +} + +func tlsAlertDesc(code byte) string { + switch code { + case 0: + return "close_notify" + case 10: + return "unexpected_message" + case 40: + return "handshake_failure" + case 42: + return "bad_certificate" + case 43: + return "unsupported_certificate" + case 44: + return "certificate_revoked" + case 45: + return "certificate_expired" + case 48: + return "unknown_ca" + case 49: + return "access_denied" + case 50: + return "decode_error" + case 70: + return "protocol_version" + case 80: + return "internal_error" + case 86: + return "inappropriate_fallback" + case 90: + return "user_canceled" + case 112: + return "unrecognized_name" + default: + return fmt.Sprintf("alert(%d)", code) + } +} + +// extractSNI parses a TLS ClientHello and returns the SNI server name. +func extractSNI(data []byte) string { + if len(data) < 6 || data[0] != 0x16 { + return "" + } + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + handshake := data[5:] + if len(handshake) > recordLen { + handshake = handshake[:recordLen] + } + + if len(handshake) < 4 || handshake[0] != 0x01 { + return "" + } + hsLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3]) + body := handshake[4:] + if len(body) > hsLen { + body = body[:hsLen] + } + + extPos := clientHelloExtensionsOffset(body) + if extPos < 0 { + return "" + } + return findSNIExtension(body, extPos) +} + +// clientHelloExtensionsOffset returns the byte offset where extensions begin +// within the ClientHello body, or -1 if the body is too short. +func clientHelloExtensionsOffset(body []byte) int { + if len(body) < 38 { + return -1 + } + pos := 34 + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // session ID + + if pos+2 > len(body) { + return -1 + } + pos += 2 + int(binary.BigEndian.Uint16(body[pos:pos+2])) // cipher suites + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // compression methods + + if pos+2 > len(body) { + return -1 + } + return pos +} + +func findSNIExtension(body []byte, pos int) string { + extLen := int(binary.BigEndian.Uint16(body[pos : pos+2])) + pos += 2 + extEnd := pos + extLen + if extEnd > len(body) { + extEnd = len(body) + } + + for pos+4 <= extEnd { + extType := binary.BigEndian.Uint16(body[pos : pos+2]) + eLen := int(binary.BigEndian.Uint16(body[pos+2 : pos+4])) + pos += 4 + if pos+eLen > extEnd { + break + } + if extType == 0 && eLen >= 5 { + nameLen := int(binary.BigEndian.Uint16(body[pos+3 : pos+5])) + if pos+5+nameLen <= extEnd { + return string(body[pos+5 : pos+5+nameLen]) + } + } + pos += eLen + } + return "" +} + +func isDNSPort(src, dst uint16) bool { + return src == 53 || dst == 53 || src == 5353 || dst == 5353 +} + +// formatDNSPayload parses DNS and returns a tcpdump-style summary. +func formatDNSPayload(payload []byte) string { + d := &layers.DNS{} + if err := d.DecodeFromBytes(payload, gopacket.NilDecodeFeedback); err != nil { + return "" + } + + rd := "" + if d.RD { + rd = "+" + } + + if !d.QR { + return formatDNSQuery(d, rd, len(payload)) + } + return formatDNSResponse(d, rd, len(payload)) +} + +func formatDNSQuery(d *layers.DNS, rd string, plen int) string { + if len(d.Questions) == 0 { + return fmt.Sprintf("%04x%s (%d)", d.ID, rd, plen) + } + q := d.Questions[0] + return fmt.Sprintf("%04x%s %s? %s. (%d)", d.ID, rd, q.Type, q.Name, plen) +} + +func formatDNSResponse(d *layers.DNS, rd string, plen int) string { + anCount := d.ANCount + nsCount := d.NSCount + arCount := d.ARCount + + if d.ResponseCode != layers.DNSResponseCodeNoErr { + return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) + } + + if anCount > 0 && len(d.Answers) > 0 { + rr := d.Answers[0] + if rdata := shortRData(&rr); rdata != "" { + return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen) + } + } + + return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) +} + +func shortRData(rr *layers.DNSResourceRecord) string { + switch rr.Type { + case layers.DNSTypeA, layers.DNSTypeAAAA: + if rr.IP != nil { + return rr.IP.String() + } + case layers.DNSTypeCNAME: + if len(rr.CNAME) > 0 { + return string(rr.CNAME) + "." + } + case layers.DNSTypePTR: + if len(rr.PTR) > 0 { + return string(rr.PTR) + "." + } + case layers.DNSTypeNS: + if len(rr.NS) > 0 { + return string(rr.NS) + "." + } + case layers.DNSTypeMX: + return fmt.Sprintf("%d %s.", rr.MX.Preference, rr.MX.Name) + case layers.DNSTypeTXT: + if len(rr.TXTs) > 0 { + return fmt.Sprintf("%q", string(rr.TXTs[0])) + } + case layers.DNSTypeSRV: + return fmt.Sprintf("%d %d %d %s.", rr.SRV.Priority, rr.SRV.Weight, rr.SRV.Port, rr.SRV.Name) + } + return "" +}