diff --git a/client/cmd/capture.go b/client/cmd/capture.go index 70522fdbb..95caaa5cd 100644 --- a/client/cmd/capture.go +++ b/client/cmd/capture.go @@ -10,10 +10,12 @@ import ( "strings" "syscall" + "github.com/hashicorp/go-multierror" "github.com/spf13/cobra" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/util/capture" ) @@ -87,7 +89,6 @@ func runCapture(cmd *cobra.Command, args []string) error { if err != nil { return err } - defer cleanup() if req.TextOutput { cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n") @@ -95,7 +96,12 @@ func runCapture(cmd *cobra.Command, args []string) error { cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n") } - return streamCapture(ctx, cmd, stream, out) + streamErr := streamCapture(ctx, cmd, stream, out) + cleanupErr := cleanup() + if streamErr != nil { + return streamErr + } + return cleanupErr } func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) { @@ -148,13 +154,12 @@ func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonS } } -// captureOutput returns the writer for capture data and a cleanup function. -func captureOutput(cmd *cobra.Command) (io.Writer, func(), error) { +// captureOutput returns the writer for capture data and a cleanup function +// that finalizes the file. Errors from the cleanup must be propagated. +func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) { outPath, _ := cmd.Flags().GetString("output") if outPath == "" { - return os.Stdout, func() { - // no cleanup needed for stdout - }, nil + return os.Stdout, func() error { return nil }, nil } f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") @@ -162,19 +167,24 @@ func captureOutput(cmd *cobra.Command) (io.Writer, func(), error) { return nil, nil, fmt.Errorf("create output file: %w", err) } tmpPath := f.Name() - return f, func() { + return f, func() error { + var merr *multierror.Error if err := f.Close(); err != nil { - cmd.PrintErrf("close output file: %v\n", err) + merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err)) } - if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 { - if err := os.Rename(tmpPath, outPath); err != nil { - cmd.PrintErrf("rename output file: %v\n", err) - } else { - cmd.PrintErrf("Wrote %s\n", outPath) + fi, statErr := os.Stat(tmpPath) + if statErr != nil || fi.Size() == 0 { + if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) { + merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr)) } - } else { - os.Remove(tmpPath) + return nberrors.FormatErrorOrNil(merr) } + if err := os.Rename(tmpPath, outPath); err != nil { + merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err)) + return nberrors.FormatErrorOrNil(merr) + } + cmd.PrintErrf("Wrote %s\n", outPath) + return nberrors.FormatErrorOrNil(merr) }, nil } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index ebb11b2d4..3787e63a8 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -384,12 +384,13 @@ func (m *Manager) initForwarder() error { return fmt.Errorf("create forwarder: %w", err) } + m.forwarder.Store(forwarder) + + // Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture. if pc := m.pendingCapture.Load(); pc != nil { forwarder.SetCapture(*pc) } - m.forwarder.Store(forwarder) - log.Debug("forwarder initialized") return nil diff --git a/client/server/capture.go b/client/server/capture.go index 17dbf9fd6..68a7a4462 100644 --- a/client/server/capture.go +++ b/client/server/capture.go @@ -46,11 +46,6 @@ func (bc *bundleCapture) stop() { if bc.cancel != nil { bc.cancel() } - if bc.engine != nil { - if err := bc.engine.SetCapture(nil); err != nil { - log.Debugf("clear bundle capture: %v", err) - } - } if bc.sess != nil { bc.sess.Stop() } @@ -87,9 +82,8 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo "packet capture is disabled; reinstall or reconfigure the service with --enable-capture") } - engine, err := s.getCaptureEngine() - if err != nil { - return err + if d := req.GetDuration(); d != nil && d.AsDuration() < 0 { + return status.Error(codes.InvalidArgument, "duration must not be negative") } matcher, err := parseCaptureFilter(req) @@ -117,7 +111,15 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo return status.Errorf(codes.Internal, "create capture session: %v", err) } + engine, err := s.claimCapture(sess) + if err != nil { + sess.Stop() + pw.Close() + return err + } + if err := engine.SetCapture(sess); err != nil { + s.releaseCapture(sess) sess.Stop() pw.Close() return status.Errorf(codes.Internal, "set capture: %v", err) @@ -127,9 +129,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo // The client waits for this before printing the banner, so it must arrive // before any packet data. if err := stream.Send(&proto.CapturePacket{}); err != nil { - if clearErr := engine.SetCapture(nil); clearErr != nil { - log.Debugf("clear capture after send failure: %v", clearErr) - } + s.clearCaptureIfOwner(sess, engine) sess.Stop() pw.Close() return status.Errorf(codes.Internal, "send initial message: %v", err) @@ -137,16 +137,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo ctx := stream.Context() if d := req.GetDuration(); d != nil { - dur := d.AsDuration() - if dur < 0 { - if clearErr := engine.SetCapture(nil); clearErr != nil { - log.Debugf("clear capture: %v", clearErr) - } - sess.Stop() - pw.Close() - return status.Errorf(codes.InvalidArgument, "duration must not be negative") - } - if dur > 0 { + if dur := d.AsDuration(); dur > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, dur) defer cancel() @@ -155,9 +146,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo go func() { <-ctx.Done() - if err := engine.SetCapture(nil); err != nil { - log.Debugf("clear capture: %v", err) - } + s.clearCaptureIfOwner(sess, engine) sess.Stop() pw.Close() }() @@ -202,6 +191,10 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap s.stopBundleCaptureLocked() s.cleanupBundleCapture() + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + engine, err := s.getCaptureEngineLocked() if err != nil { // Not fatal: kernel mode or not connected. Log and return success @@ -234,6 +227,7 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err) return &proto.StartBundleCaptureResponse{}, nil } + s.activeCapture = sess ctx, cancel := context.WithTimeout(context.Background(), timeout) bc := &bundleCapture{ @@ -243,13 +237,19 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap cancel: cancel, } + s.bundleCapture = bc + go func() { <-ctx.Done() - bc.stop() + s.mutex.Lock() + if s.bundleCapture == bc { + s.stopBundleCaptureLocked() + } else { + bc.stop() + } + s.mutex.Unlock() log.Infof("bundle capture auto-stopped after timeout") }() - - s.bundleCapture = bc log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name()) return &proto.StartBundleCaptureResponse{}, nil @@ -269,9 +269,16 @@ func (s *Server) stopBundleCaptureLocked() { if s.bundleCapture == nil { return } - s.bundleCapture.stop() + bc := s.bundleCapture + if bc.engine != nil && s.activeCapture == bc.sess { + if err := bc.engine.SetCapture(nil); err != nil { + log.Debugf("clear bundle capture: %v", err) + } + s.activeCapture = nil + } + bc.stop() - stats := s.bundleCapture.sess.Stats() + stats := bc.sess.Stats() log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped", stats.Packets, stats.Bytes, stats.Dropped) } @@ -303,6 +310,45 @@ func (s *Server) getCaptureEngine() (*internal.Engine, error) { return s.getCaptureEngineLocked() } +// claimCapture reserves the engine's capture slot for sess. Returns +// FailedPrecondition if another capture is already active. +func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + engine, err := s.getCaptureEngineLocked() + if err != nil { + return nil, err + } + s.activeCapture = sess + return engine, nil +} + +// releaseCapture clears the active-capture owner if it still matches sess. +func (s *Server) releaseCapture(sess *capture.Session) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture == sess { + s.activeCapture = nil + } +} + +// clearCaptureIfOwner clears engine's capture slot only if sess still owns it. +func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture != sess { + return + } + if err := engine.SetCapture(nil); err != nil { + log.Debugf("clear capture: %v", err) + } + s.activeCapture = nil +} + func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) { if s.connectClient == nil { return nil, status.Error(codes.FailedPrecondition, "client not connected") diff --git a/util/capture/afpacket_linux.go b/util/capture/afpacket_linux.go index 8c382dba2..bf59e806a 100644 --- a/util/capture/afpacket_linux.go +++ b/util/capture/afpacket_linux.go @@ -54,17 +54,23 @@ func (c *AFPacketCapture) Start() error { if c.sess == nil { return errors.New("nil capture session") } - if c.started.Load() { + if !c.started.CompareAndSwap(false, true) { return errors.New("capture already started") } + if c.closed.Load() { + c.started.Store(false) + return errors.New("cannot restart stopped capture") + } iface, err := net.InterfaceByName(c.ifaceName) if err != nil { + c.started.Store(false) return fmt.Errorf("interface %s: %w", c.ifaceName, err) } fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL))) if err != nil { + c.started.Store(false) return fmt.Errorf("create AF_PACKET socket: %w", err) } @@ -74,6 +80,7 @@ func (c *AFPacketCapture) Start() error { } if err := unix.Bind(fd, addr); err != nil { unix.Close(fd) + c.started.Store(false) return fmt.Errorf("bind to %s: %w", c.ifaceName, err) } @@ -81,7 +88,6 @@ func (c *AFPacketCapture) Start() error { c.fd = fd c.mu.Unlock() - c.started.Store(true) go c.readLoop(fd) return nil }