From b734534f3ce619a2b87e0f4e72e21d5e1e5d7df9 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 15 Apr 2026 19:50:31 +0200 Subject: [PATCH] Refactor embed capture API to StartCapture/StopCapture --- client/embed/capture.go | 65 +++++++++++++++++++++++ client/embed/embed.go | 68 ++++++++++++++++++------- client/wasm/internal/capture/capture.go | 61 +++++----------------- proxy/internal/debug/handler.go | 33 +++--------- 4 files changed, 134 insertions(+), 93 deletions(-) create mode 100644 client/embed/capture.go diff --git a/client/embed/capture.go b/client/embed/capture.go new file mode 100644 index 000000000..30f9b496f --- /dev/null +++ b/client/embed/capture.go @@ -0,0 +1,65 @@ +package embed + +import ( + "io" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/util/capture" +) + +// CaptureOptions configures a packet capture session. +type CaptureOptions struct { + // Output receives pcap-formatted data. Nil disables pcap output. + Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443"). + // Empty captures all packets. + Filter string + // Verbose adds seq/ack, TTL, window, and total length to text output. + Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool +} + +// CaptureStats reports capture session counters. +type CaptureStats struct { + Packets int64 + Bytes int64 + Dropped int64 +} + +// CaptureSession represents an active packet capture. Call Stop to end the +// capture and flush buffered packets. +type CaptureSession struct { + sess *capture.Session + engine *internal.Engine +} + +// Stop ends the capture, flushes remaining packets, and detaches from the device. +// Safe to call multiple times. +func (cs *CaptureSession) Stop() { + if cs.engine != nil { + _ = cs.engine.SetCapture(nil) + cs.engine = nil + } + if cs.sess != nil { + cs.sess.Stop() + } +} + +// Stats returns current capture counters. +func (cs *CaptureSession) Stats() CaptureStats { + s := cs.sess.Stats() + return CaptureStats{ + Packets: s.Packets, + Bytes: s.Bytes, + Dropped: s.Dropped, + } +} + +// Done returns a channel that is closed when the capture's writer goroutine +// has fully exited and all buffered packets have been flushed. +func (cs *CaptureSession) Done() <-chan struct{} { + return cs.sess.Done() +} diff --git a/client/embed/embed.go b/client/embed/embed.go index f9a22718e..baa1d94d6 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -66,7 +66,7 @@ type Options struct { PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string - // PreSharedKey is the pre-shared key for the WireGuard interface + // PreSharedKey is the pre-shared key for the tunnel interface PreSharedKey string // LogOutput is the output destination for logs (defaults to os.Stderr if nil) LogOutput io.Writer @@ -82,9 +82,9 @@ type Options struct { DisableClientRoutes bool // BlockInbound blocks all inbound connections from peers BlockInbound bool - // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. + // WireguardPort is the port for the tunnel interface. Use 0 for a random port. WireguardPort *int - // MTU is the MTU for the WireGuard interface. + // MTU is the MTU for the tunnel interface. // Valid values are in the range 576..8192 bytes. // If non-nil, this value overrides any value stored in the config file. // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. @@ -470,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// StartCapture begins capturing packets on this client's tunnel device. +// Only one capture can be active at a time; starting a new one stops the previous. +// Call StopCapture (or CaptureSession.Stop) to end it. +func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + var matcher capture.Matcher + if opts.Filter != "" { + m, err := capture.ParseFilter(opts.Filter) + if err != nil { + return nil, fmt.Errorf("parse filter: %w", err) + } + matcher = m + } + + sess, err := capture.NewSession(capture.Options{ + Output: opts.Output, + TextOutput: opts.TextOutput, + Matcher: matcher, + Verbose: opts.Verbose, + ASCII: opts.ASCII, + }) + if err != nil { + return nil, fmt.Errorf("create capture session: %w", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + return nil, fmt.Errorf("set capture: %w", err) + } + + return &CaptureSession{sess: sess, engine: engine}, nil +} + +// StopCapture stops the active capture session if one is running. +func (c *Client) StopCapture() error { + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetCapture(nil) +} + // getEngine safely retrieves the engine from the client with proper locking. // Returns ErrClientNotStarted if the client is not started. // Returns ErrEngineNotStarted if the engine is not available. @@ -490,22 +536,6 @@ func (c *Client) getEngine() (*internal.Engine, error) { return engine, nil } -// SetCapture sets or clears packet capture on this client's WireGuard device. -// Pass nil to disable capture. -func (c *Client) SetCapture(sess *capture.Session) error { - engine, err := c.getEngine() - if err != nil { - return err - } - // Explicit nil check to avoid wrapping a nil *Session in the - // device.PacketCapture interface, which would appear non-nil to - // FilteredDevice and cause a nil-pointer dereference in Offer. - if sess == nil { - return engine.SetCapture(nil) - } - return engine.SetCapture(sess) -} - func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { engine, err := c.getEngine() if err != nil { diff --git a/client/wasm/internal/capture/capture.go b/client/wasm/internal/capture/capture.go index f6dc263f8..53e43c45e 100644 --- a/client/wasm/internal/capture/capture.go +++ b/client/wasm/internal/capture/capture.go @@ -8,37 +8,29 @@ import ( "sync" "syscall/js" - log "github.com/sirupsen/logrus" - netbird "github.com/netbirdio/netbird/client/embed" - "github.com/netbirdio/netbird/util/capture" ) -// Handle holds a running capture session and the embedded client reference -// so it can be stopped later. +// Handle holds a running capture session so it can be stopped later. type Handle struct { - client *netbird.Client - sess *capture.Session + cs *netbird.CaptureSession stopFn js.Func stopped bool } // Stop ends the capture and returns stats. -func (h *Handle) Stop() capture.Stats { +func (h *Handle) Stop() netbird.CaptureStats { if h.stopped { - return h.sess.Stats() + return h.cs.Stats() } h.stopped = true h.stopFn.Release() - if err := h.client.SetCapture(nil); err != nil { - log.Debugf("clear capture: %v", err) - } - h.sess.Stop() - return h.sess.Stats() + h.cs.Stop() + return h.cs.Stats() } -func statsToJS(s capture.Stats) js.Value { +func statsToJS(s netbird.CaptureStats) js.Value { obj := js.Global().Get("Object").Call("create", js.Null()) obj.Set("packets", js.ValueOf(s.Packets)) obj.Set("bytes", js.ValueOf(s.Bytes)) @@ -70,13 +62,6 @@ func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) { return } -func buildMatcher(filter string) (capture.Matcher, error) { - if filter == "" { - return nil, nil - } - return capture.ParseFilter(filter) -} - // Start creates a capture session and returns a JS interface for streaming text // output. The returned object exposes: // @@ -87,16 +72,11 @@ func buildMatcher(filter string) (capture.Matcher, error) { func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { filter, verbose, ascii := parseOpts(jsOpts) - matcher, err := buildMatcher(filter) - if err != nil { - return js.Undefined(), err - } - cb := &jsCallbackWriter{} - sess, err := capture.NewSession(capture.Options{ + cs, err := client.StartCapture(netbird.CaptureOptions{ TextOutput: cb, - Matcher: matcher, + Filter: filter, Verbose: verbose, ASCII: ascii, }) @@ -104,12 +84,7 @@ func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { return js.Undefined(), err } - if err := client.SetCapture(sess); err != nil { - sess.Stop() - return js.Undefined(), err - } - - handle := &Handle{client: client, sess: sess} + handle := &Handle{cs: cs} iface := js.Global().Get("Object").Call("create", js.Null()) handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { @@ -127,16 +102,11 @@ func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) { filter, verbose, ascii := parseOpts(jsOpts) - matcher, err := buildMatcher(filter) - if err != nil { - return nil, err - } - cb := &jsCallbackWriter{} - sess, err := capture.NewSession(capture.Options{ + cs, err := client.StartCapture(netbird.CaptureOptions{ TextOutput: cb, - Matcher: matcher, + Filter: filter, Verbose: verbose, ASCII: ascii, }) @@ -144,12 +114,7 @@ func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) { return nil, err } - if err := client.SetCapture(sess); err != nil { - sess.Stop() - return nil, err - } - - handle := &Handle{client: client, sess: sess} + handle := &Handle{cs: cs} handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { return statsToJS(handle.Stop()) }) diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 36299837b..6cd124554 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/types" - "github.com/netbirdio/netbird/util/capture" "github.com/netbirdio/netbird/version" ) @@ -667,21 +666,12 @@ func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountI } } - var matcher capture.Matcher - if expr := r.URL.Query().Get("filter"); expr != "" { - var err error - matcher, err = capture.ParseFilter(expr) - if err != nil { - http.Error(w, "invalid filter: "+err.Error(), http.StatusBadRequest) - return - } - } - + filter := r.URL.Query().Get("filter") wantText := r.URL.Query().Get("format") == "text" verbose := r.URL.Query().Get("verbose") == "true" ascii := r.URL.Query().Get("ascii") == "true" - opts := capture.Options{Matcher: matcher, Verbose: verbose, ASCII: ascii} + opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii} if wantText { w.Header().Set("Content-Type", "text/plain; charset=utf-8") opts.TextOutput = w @@ -692,18 +682,12 @@ func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountI opts.Output = w } - sess, err := capture.NewSession(opts) + cs, err := client.StartCapture(opts) if err != nil { - http.Error(w, "create capture session: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable) return } - defer sess.Stop() - - if err := client.SetCapture(sess); err != nil { - http.Error(w, "set capture: "+err.Error(), http.StatusServiceUnavailable) - return - } - defer client.SetCapture(nil) //nolint:errcheck + defer cs.Stop() // Flush headers after setup succeeds so errors above can still set status codes. if f, ok := w.(http.Flusher); ok { @@ -718,12 +702,9 @@ func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountI case <-timer.C: } - if err := client.SetCapture(nil); err != nil { - h.logger.Debugf("clear capture: %v", err) - } - sess.Stop() + cs.Stop() - stats := sess.Stats() + stats := cs.Stats() h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped", accountID, stats.Packets, stats.Bytes, stats.Dropped) }