Refactor embed capture API to StartCapture/StopCapture

This commit is contained in:
Viktor Liu
2026-04-15 19:50:31 +02:00
parent e58c29d4f9
commit b734534f3c
4 changed files with 134 additions and 93 deletions

65
client/embed/capture.go Normal file
View File

@@ -0,0 +1,65 @@
package embed
import (
"io"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/capture"
)
// CaptureOptions configures a packet capture session.
type CaptureOptions struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
// Empty captures all packets.
Filter string
// Verbose adds seq/ack, TTL, window, and total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
}
// CaptureStats reports capture session counters.
type CaptureStats struct {
Packets int64
Bytes int64
Dropped int64
}
// CaptureSession represents an active packet capture. Call Stop to end the
// capture and flush buffered packets.
type CaptureSession struct {
sess *capture.Session
engine *internal.Engine
}
// Stop ends the capture, flushes remaining packets, and detaches from the device.
// Safe to call multiple times.
func (cs *CaptureSession) Stop() {
if cs.engine != nil {
_ = cs.engine.SetCapture(nil)
cs.engine = nil
}
if cs.sess != nil {
cs.sess.Stop()
}
}
// Stats returns current capture counters.
func (cs *CaptureSession) Stats() CaptureStats {
s := cs.sess.Stats()
return CaptureStats{
Packets: s.Packets,
Bytes: s.Bytes,
Dropped: s.Dropped,
}
}
// Done returns a channel that is closed when the capture's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (cs *CaptureSession) Done() <-chan struct{} {
return cs.sess.Done()
}

View File

@@ -66,7 +66,7 @@ type Options struct {
PrivateKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
// PreSharedKey is the pre-shared key for the tunnel interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
@@ -82,9 +82,9 @@ type Options struct {
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the WireGuard interface.
// MTU is the MTU for the tunnel interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
@@ -470,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
var matcher capture.Matcher
if opts.Filter != "" {
m, err := capture.ParseFilter(opts.Filter)
if err != nil {
return nil, fmt.Errorf("parse filter: %w", err)
}
matcher = m
}
sess, err := capture.NewSession(capture.Options{
Output: opts.Output,
TextOutput: opts.TextOutput,
Matcher: matcher,
Verbose: opts.Verbose,
ASCII: opts.ASCII,
})
if err != nil {
return nil, fmt.Errorf("create capture session: %w", err)
}
if err := engine.SetCapture(sess); err != nil {
sess.Stop()
return nil, fmt.Errorf("set capture: %w", err)
}
return &CaptureSession{sess: sess, engine: engine}, nil
}
// StopCapture stops the active capture session if one is running.
func (c *Client) StopCapture() error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetCapture(nil)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.
@@ -490,22 +536,6 @@ func (c *Client) getEngine() (*internal.Engine, error) {
return engine, nil
}
// SetCapture sets or clears packet capture on this client's WireGuard device.
// Pass nil to disable capture.
func (c *Client) SetCapture(sess *capture.Session) error {
engine, err := c.getEngine()
if err != nil {
return err
}
// Explicit nil check to avoid wrapping a nil *Session in the
// device.PacketCapture interface, which would appear non-nil to
// FilteredDevice and cause a nil-pointer dereference in Offer.
if sess == nil {
return engine.SetCapture(nil)
}
return engine.SetCapture(sess)
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
engine, err := c.getEngine()
if err != nil {

View File

@@ -8,37 +8,29 @@ import (
"sync"
"syscall/js"
log "github.com/sirupsen/logrus"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/util/capture"
)
// Handle holds a running capture session and the embedded client reference
// so it can be stopped later.
// Handle holds a running capture session so it can be stopped later.
type Handle struct {
client *netbird.Client
sess *capture.Session
cs *netbird.CaptureSession
stopFn js.Func
stopped bool
}
// Stop ends the capture and returns stats.
func (h *Handle) Stop() capture.Stats {
func (h *Handle) Stop() netbird.CaptureStats {
if h.stopped {
return h.sess.Stats()
return h.cs.Stats()
}
h.stopped = true
h.stopFn.Release()
if err := h.client.SetCapture(nil); err != nil {
log.Debugf("clear capture: %v", err)
}
h.sess.Stop()
return h.sess.Stats()
h.cs.Stop()
return h.cs.Stats()
}
func statsToJS(s capture.Stats) js.Value {
func statsToJS(s netbird.CaptureStats) js.Value {
obj := js.Global().Get("Object").Call("create", js.Null())
obj.Set("packets", js.ValueOf(s.Packets))
obj.Set("bytes", js.ValueOf(s.Bytes))
@@ -70,13 +62,6 @@ func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) {
return
}
func buildMatcher(filter string) (capture.Matcher, error) {
if filter == "" {
return nil, nil
}
return capture.ParseFilter(filter)
}
// Start creates a capture session and returns a JS interface for streaming text
// output. The returned object exposes:
//
@@ -87,16 +72,11 @@ func buildMatcher(filter string) (capture.Matcher, error) {
func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
filter, verbose, ascii := parseOpts(jsOpts)
matcher, err := buildMatcher(filter)
if err != nil {
return js.Undefined(), err
}
cb := &jsCallbackWriter{}
sess, err := capture.NewSession(capture.Options{
cs, err := client.StartCapture(netbird.CaptureOptions{
TextOutput: cb,
Matcher: matcher,
Filter: filter,
Verbose: verbose,
ASCII: ascii,
})
@@ -104,12 +84,7 @@ func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
return js.Undefined(), err
}
if err := client.SetCapture(sess); err != nil {
sess.Stop()
return js.Undefined(), err
}
handle := &Handle{client: client, sess: sess}
handle := &Handle{cs: cs}
iface := js.Global().Get("Object").Call("create", js.Null())
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
@@ -127,16 +102,11 @@ func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) {
filter, verbose, ascii := parseOpts(jsOpts)
matcher, err := buildMatcher(filter)
if err != nil {
return nil, err
}
cb := &jsCallbackWriter{}
sess, err := capture.NewSession(capture.Options{
cs, err := client.StartCapture(netbird.CaptureOptions{
TextOutput: cb,
Matcher: matcher,
Filter: filter,
Verbose: verbose,
ASCII: ascii,
})
@@ -144,12 +114,7 @@ func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) {
return nil, err
}
if err := client.SetCapture(sess); err != nil {
sess.Stop()
return nil, err
}
handle := &Handle{client: client, sess: sess}
handle := &Handle{cs: cs}
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
return statsToJS(handle.Stop())
})

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/util/capture"
"github.com/netbirdio/netbird/version"
)
@@ -667,21 +666,12 @@ func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountI
}
}
var matcher capture.Matcher
if expr := r.URL.Query().Get("filter"); expr != "" {
var err error
matcher, err = capture.ParseFilter(expr)
if err != nil {
http.Error(w, "invalid filter: "+err.Error(), http.StatusBadRequest)
return
}
}
filter := r.URL.Query().Get("filter")
wantText := r.URL.Query().Get("format") == "text"
verbose := r.URL.Query().Get("verbose") == "true"
ascii := r.URL.Query().Get("ascii") == "true"
opts := capture.Options{Matcher: matcher, Verbose: verbose, ASCII: ascii}
opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii}
if wantText {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
opts.TextOutput = w
@@ -692,18 +682,12 @@ func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountI
opts.Output = w
}
sess, err := capture.NewSession(opts)
cs, err := client.StartCapture(opts)
if err != nil {
http.Error(w, "create capture session: "+err.Error(), http.StatusInternalServerError)
http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable)
return
}
defer sess.Stop()
if err := client.SetCapture(sess); err != nil {
http.Error(w, "set capture: "+err.Error(), http.StatusServiceUnavailable)
return
}
defer client.SetCapture(nil) //nolint:errcheck
defer cs.Stop()
// Flush headers after setup succeeds so errors above can still set status codes.
if f, ok := w.(http.Flusher); ok {
@@ -718,12 +702,9 @@ func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountI
case <-timer.C:
}
if err := client.SetCapture(nil); err != nil {
h.logger.Debugf("clear capture: %v", err)
}
sess.Stop()
cs.Stop()
stats := sess.Stats()
stats := cs.Stats()
h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped",
accountID, stats.Packets, stats.Bytes, stats.Dropped)
}